Skip to content

Commit c62f4e0

Browse files
committed
improve type hints, apply type related fixes
1 parent 75d69e5 commit c62f4e0

File tree

1 file changed

+53
-34
lines changed

1 file changed

+53
-34
lines changed

pytest_httpserver/httpserver.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Callable
1919
from typing import ClassVar
2020
from typing import Iterable
21+
from typing import List
2122
from typing import Mapping
2223
from typing import MutableMapping
2324
from typing import Optional
@@ -34,6 +35,9 @@
3435

3536
if TYPE_CHECKING:
3637
from ssl import SSLContext
38+
from types import TracebackType
39+
40+
from werkzeug.serving import BaseWSGIServer
3741

3842
URI_DEFAULT = ""
3943
METHOD_ALL = "__ALL"
@@ -113,11 +117,13 @@ def complete(self, result: bool): # noqa: FBT001
113117

114118
@property
115119
def result(self) -> bool:
116-
return self._result
120+
return bool(self._result)
117121

118122
@property
119123
def elapsed_time(self) -> float:
120124
"""Elapsed time in seconds"""
125+
if self._stop is None:
126+
raise TypeError("unsupported operand type(s) for -: 'NoneType' and 'float'")
121127
return self._stop - self._start
122128

123129

@@ -139,7 +145,7 @@ def authorization_header_value_matcher(actual: str | None, expected: str) -> boo
139145
func = getattr(Authorization, "from_header", None)
140146
if func is None: # Werkzeug < 2.3.0
141147
func = werkzeug.http.parse_authorization_header # type: ignore[attr-defined]
142-
return func(actual) == func(expected)
148+
return func(actual) == func(expected) # type: ignore
143149

144150
@staticmethod
145151
def default_header_value_matcher(actual: str | None, expected: str) -> bool:
@@ -174,7 +180,7 @@ def match(self, request_query_string: bytes) -> bool:
174180
return values[0] == values[1]
175181

176182
@abc.abstractmethod
177-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
183+
def get_comparing_values(self, request_query_string: bytes) -> tuple[Any, Any]:
178184
pass
179185

180186

@@ -195,10 +201,10 @@ def __init__(self, query_string: bytes | str):
195201

196202
self.query_string = query_string
197203

198-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
204+
def get_comparing_values(self, request_query_string: bytes) -> tuple[bytes, bytes]:
199205
if isinstance(self.query_string, str):
200206
query_string = self.query_string.encode()
201-
elif isinstance(self.query_string, bytes):
207+
elif isinstance(self.query_string, bytes): # type: ignore
202208
query_string = self.query_string
203209
else:
204210
raise TypeError("query_string must be a string, or a bytes-like object")
@@ -211,7 +217,7 @@ class MappingQueryMatcher(QueryMatcher):
211217
Matches a query string to a dictionary or MultiDict specified
212218
"""
213219

214-
def __init__(self, query_dict: Mapping | MultiDict):
220+
def __init__(self, query_dict: Mapping[str, str] | MultiDict[str, str]):
215221
"""
216222
:param query_dict: if dictionary (Mapping) is specified, it will be used as a
217223
key-value mapping where both key and value should be string. If there are multiple
@@ -221,7 +227,7 @@ def __init__(self, query_dict: Mapping | MultiDict):
221227
"""
222228
self.query_dict = query_dict
223229

224-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
230+
def get_comparing_values(self, request_query_string: bytes) -> tuple[Mapping[str, str], Mapping[str, str]]:
225231
query = MultiDict(urllib.parse.parse_qsl(request_query_string.decode("utf-8")))
226232
if isinstance(self.query_dict, MultiDict):
227233
return (query, self.query_dict)
@@ -241,14 +247,14 @@ def __init__(self, result: bool): # noqa: FBT001
241247
"""
242248
self.result = result
243249

244-
def get_comparing_values(self, request_query_string): # noqa: ARG002
250+
def get_comparing_values(self, request_query_string: bytes): # noqa: ARG002
245251
if self.result:
246252
return (True, True)
247253
else:
248254
return (True, False)
249255

250256

251-
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping) -> QueryMatcher:
257+
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping[str, str]) -> QueryMatcher:
252258
if isinstance(query_string, QueryMatcher):
253259
return query_string
254260

@@ -312,7 +318,7 @@ def __init__(
312318
data: str | bytes | None = None,
313319
data_encoding: str = "utf-8",
314320
headers: Mapping[str, str] | None = None,
315-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
321+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
316322
header_value_matcher: HVMATCHER_T | None = None,
317323
json: Any = UNDEFINED,
318324
):
@@ -410,7 +416,7 @@ def match_json(self, request: Request) -> bool:
410416

411417
return json_received == self.json
412418

413-
def difference(self, request: Request) -> list[tuple]:
419+
def difference(self, request: Request) -> list[tuple[str, str, str | URIPattern]]:
414420
"""
415421
Calculates the difference between the matcher and the request.
416422
@@ -422,7 +428,7 @@ def difference(self, request: Request) -> list[tuple]:
422428
matches the fields set in the matcher object.
423429
"""
424430

425-
retval: list[tuple] = []
431+
retval: list[tuple[str, Any, Any]] = []
426432

427433
if not self.match_uri(request):
428434
retval.append(("uri", request.path, self.uri))
@@ -433,8 +439,8 @@ def difference(self, request: Request) -> list[tuple]:
433439
if not self.query_matcher.match(request.query_string):
434440
retval.append(("query_string", request.query_string, self.query_string))
435441

436-
request_headers = {}
437-
expected_headers = {}
442+
request_headers: dict[str, str | None] = {}
443+
expected_headers: dict[str, str] = {}
438444
for key, value in self.headers.items():
439445
if not self.header_value_matcher(key, request.headers.get(key), value):
440446
request_headers[key] = request.headers.get(key)
@@ -467,7 +473,7 @@ class RequestHandlerBase(abc.ABC):
467473

468474
def respond_with_json(
469475
self,
470-
response_json,
476+
response_json: Any,
471477
status: int = 200,
472478
headers: Mapping[str, str] | None = None,
473479
content_type: str = "application/json",
@@ -578,7 +584,7 @@ def __repr__(self) -> str:
578584
return retval
579585

580586

581-
class RequestHandlerList(list):
587+
class RequestHandlerList(List[RequestHandler]):
582588
"""
583589
Represents a list of :py:class:`RequestHandler` objects.
584590
@@ -638,9 +644,9 @@ def __init__(
638644
"""
639645
self.host = host
640646
self.port = port
641-
self.server = None
642-
self.server_thread = None
643-
self.assertions: list[str] = []
647+
self.server: BaseWSGIServer | None = None
648+
self.server_thread: threading.Thread | None = None
649+
self.assertions: list[str | AssertionError] = []
644650
self.handler_errors: list[Exception] = []
645651
self.log: list[tuple[Request, Response]] = []
646652
self.ssl_context = ssl_context
@@ -727,7 +733,7 @@ def thread_target(self):
727733
728734
This should not be called directly, but can be overridden to tailor it to your needs.
729735
"""
730-
736+
assert self.server is not None
731737
self.server.serve_forever()
732738

733739
def is_running(self) -> bool:
@@ -736,7 +742,7 @@ def is_running(self) -> bool:
736742
"""
737743
return bool(self.server)
738744

739-
def start(self):
745+
def start(self) -> None:
740746
"""
741747
Start the server in a thread.
742748
@@ -755,9 +761,16 @@ def start(self):
755761
if self.is_running():
756762
raise HTTPServerError("Server is already running")
757763

764+
app = Request.application(self.application)
765+
758766
self.server = make_server(
759-
self.host, self.port, self.application, ssl_context=self.ssl_context, threaded=self.threaded
767+
self.host,
768+
self.port,
769+
app,
770+
ssl_context=self.ssl_context,
771+
threaded=self.threaded,
760772
)
773+
761774
self.port = self.server.port # Update port (needed if `port` was set to 0)
762775
self.server_thread = threading.Thread(target=self.thread_target)
763776
self.server_thread.start()
@@ -772,14 +785,16 @@ def stop(self):
772785
Only a running server can be stopped. If the sever is not running, :py:class`HTTPServerError`
773786
will be raised.
774787
"""
788+
assert self.server is not None
789+
assert self.server_thread is not None
775790
if not self.is_running():
776791
raise HTTPServerError("Server is not running")
777792
self.server.shutdown()
778793
self.server_thread.join()
779794
self.server = None
780795
self.server_thread = None
781796

782-
def add_assertion(self, obj):
797+
def add_assertion(self, obj: str | AssertionError):
783798
"""
784799
Add a new assertion
785800
@@ -848,8 +863,7 @@ def dispatch(self, request: Request) -> Response:
848863
:return: the response object what the handler responded, or a response which contains the error
849864
"""
850865

851-
@Request.application # type: ignore
852-
def application(self, request: Request):
866+
def application(self, request: Request) -> Response:
853867
"""
854868
Entry point of werkzeug.
855869
@@ -875,7 +889,12 @@ def __enter__(self):
875889
self.start()
876890
return self
877891

878-
def __exit__(self, *args, **kwargs):
892+
def __exit__(
893+
self,
894+
exc_type: type[BaseException] | None,
895+
exc_value: BaseException | None,
896+
traceback: TracebackType | None,
897+
):
879898
"""
880899
Provide the context API
881900
@@ -886,7 +905,7 @@ def __exit__(self, *args, **kwargs):
886905
self.stop()
887906

888907
@staticmethod
889-
def format_host(host):
908+
def format_host(host: str):
890909
"""
891910
Formats a hostname so it can be used in a URL.
892911
Notably, this adds brackets around IPV6 addresses when
@@ -929,8 +948,8 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute
929948

930949
def __init__(
931950
self,
932-
host=DEFAULT_LISTEN_HOST,
933-
port=DEFAULT_LISTEN_PORT,
951+
host: str = DEFAULT_LISTEN_HOST,
952+
port: int = DEFAULT_LISTEN_PORT,
934953
ssl_context: SSLContext | None = None,
935954
default_waiting_settings: WaitingSettings | None = None,
936955
*,
@@ -979,7 +998,7 @@ def expect_request(
979998
data: str | bytes | None = None,
980999
data_encoding: str = "utf-8",
9811000
headers: Mapping[str, str] | None = None,
982-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1001+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
9831002
header_value_matcher: HVMATCHER_T | None = None,
9841003
handler_type: HandlerType = HandlerType.PERMANENT,
9851004
json: Any = UNDEFINED,
@@ -1062,7 +1081,7 @@ def expect_oneshot_request(
10621081
data: str | bytes | None = None,
10631082
data_encoding: str = "utf-8",
10641083
headers: Mapping[str, str] | None = None,
1065-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1084+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
10661085
header_value_matcher: HVMATCHER_T | None = None,
10671086
json: Any = UNDEFINED,
10681087
) -> RequestHandler:
@@ -1117,7 +1136,7 @@ def expect_ordered_request(
11171136
data: str | bytes | None = None,
11181137
data_encoding: str = "utf-8",
11191138
headers: Mapping[str, str] | None = None,
1120-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1139+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
11211140
header_value_matcher: HVMATCHER_T | None = None,
11221141
json: Any = UNDEFINED,
11231142
) -> RequestHandler:
@@ -1175,13 +1194,13 @@ def format_matchers(self) -> str:
11751194
This method is primarily used when reporting errors.
11761195
"""
11771196

1178-
def format_handlers(handlers):
1197+
def format_handlers(handlers: list[RequestHandler]):
11791198
if handlers:
11801199
return [" {!r}".format(handler.matcher) for handler in handlers]
11811200
else:
11821201
return [" none"]
11831202

1184-
lines = []
1203+
lines: list[str] = []
11851204
lines.append("Ordered matchers:")
11861205
lines.extend(format_handlers(self.ordered_handlers))
11871206
lines.append("")

0 commit comments

Comments
 (0)