Skip to content

Commit e75439d

Browse files
committed
improve type hints, apply type related fixes
1 parent 3961aa7 commit e75439d

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

pytest_httpserver/httpserver.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Callable
1919
from typing import ClassVar
2020
from typing import Iterable
21+
from typing import List
2122
from typing import Mapping
2223
from typing import MutableMapping
2324
from typing import Optional
@@ -34,6 +35,9 @@
3435

3536
if TYPE_CHECKING:
3637
from ssl import SSLContext
38+
from types import TracebackType
39+
40+
from werkzeug.serving import BaseWSGIServer
3741

3842
URI_DEFAULT = ""
3943
METHOD_ALL = "__ALL"
@@ -113,11 +117,13 @@ def complete(self, result: bool): # noqa: FBT001
113117

114118
@property
115119
def result(self) -> bool:
116-
return self._result
120+
return bool(self._result)
117121

118122
@property
119123
def elapsed_time(self) -> float:
120124
"""Elapsed time in seconds"""
125+
if self._stop is None:
126+
raise TypeError("unsupported operand type(s) for -: 'NoneType' and 'float'")
121127
return self._stop - self._start
122128

123129

@@ -139,7 +145,8 @@ def authorization_header_value_matcher(actual: str | None, expected: str) -> boo
139145
func = getattr(Authorization, "from_header", None)
140146
if func is None: # Werkzeug < 2.3.0
141147
func = werkzeug.http.parse_authorization_header # type: ignore[attr-defined]
142-
return func(actual) == func(expected)
148+
149+
return func(actual) == func(expected) # type: ignore
143150

144151
@staticmethod
145152
def default_header_value_matcher(actual: str | None, expected: str) -> bool:
@@ -174,7 +181,7 @@ def match(self, request_query_string: bytes) -> bool:
174181
return values[0] == values[1]
175182

176183
@abc.abstractmethod
177-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
184+
def get_comparing_values(self, request_query_string: bytes) -> tuple[Any, Any]:
178185
pass
179186

180187

@@ -195,10 +202,10 @@ def __init__(self, query_string: bytes | str):
195202

196203
self.query_string = query_string
197204

198-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
205+
def get_comparing_values(self, request_query_string: bytes) -> tuple[bytes, bytes]:
199206
if isinstance(self.query_string, str):
200207
query_string = self.query_string.encode()
201-
elif isinstance(self.query_string, bytes):
208+
elif isinstance(self.query_string, bytes): # type: ignore
202209
query_string = self.query_string
203210
else:
204211
raise TypeError("query_string must be a string, or a bytes-like object")
@@ -211,7 +218,7 @@ class MappingQueryMatcher(QueryMatcher):
211218
Matches a query string to a dictionary or MultiDict specified
212219
"""
213220

214-
def __init__(self, query_dict: Mapping | MultiDict):
221+
def __init__(self, query_dict: Mapping[str, str] | MultiDict[str, str]):
215222
"""
216223
:param query_dict: if dictionary (Mapping) is specified, it will be used as a
217224
key-value mapping where both key and value should be string. If there are multiple
@@ -221,7 +228,7 @@ def __init__(self, query_dict: Mapping | MultiDict):
221228
"""
222229
self.query_dict = query_dict
223230

224-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
231+
def get_comparing_values(self, request_query_string: bytes) -> tuple[Mapping[str, str], Mapping[str, str]]:
225232
query = MultiDict(urllib.parse.parse_qsl(request_query_string.decode("utf-8")))
226233
if isinstance(self.query_dict, MultiDict):
227234
return (query, self.query_dict)
@@ -241,14 +248,14 @@ def __init__(self, result: bool): # noqa: FBT001
241248
"""
242249
self.result = result
243250

244-
def get_comparing_values(self, request_query_string): # noqa: ARG002
251+
def get_comparing_values(self, request_query_string: bytes): # noqa: ARG002
245252
if self.result:
246253
return (True, True)
247254
else:
248255
return (True, False)
249256

250257

251-
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping) -> QueryMatcher:
258+
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping[str, str]) -> QueryMatcher:
252259
if isinstance(query_string, QueryMatcher):
253260
return query_string
254261

@@ -312,7 +319,7 @@ def __init__(
312319
data: str | bytes | None = None,
313320
data_encoding: str = "utf-8",
314321
headers: Mapping[str, str] | None = None,
315-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
322+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
316323
header_value_matcher: HVMATCHER_T | None = None,
317324
json: Any = UNDEFINED,
318325
):
@@ -410,7 +417,7 @@ def match_json(self, request: Request) -> bool:
410417

411418
return json_received == self.json
412419

413-
def difference(self, request: Request) -> list[tuple]:
420+
def difference(self, request: Request) -> list[tuple[str, str, str | URIPattern]]:
414421
"""
415422
Calculates the difference between the matcher and the request.
416423
@@ -422,7 +429,7 @@ def difference(self, request: Request) -> list[tuple]:
422429
matches the fields set in the matcher object.
423430
"""
424431

425-
retval: list[tuple] = []
432+
retval: list[tuple[str, Any, Any]] = []
426433

427434
if not self.match_uri(request):
428435
retval.append(("uri", request.path, self.uri))
@@ -433,8 +440,8 @@ def difference(self, request: Request) -> list[tuple]:
433440
if not self.query_matcher.match(request.query_string):
434441
retval.append(("query_string", request.query_string, self.query_string))
435442

436-
request_headers = {}
437-
expected_headers = {}
443+
request_headers: dict[str, str | None] = {}
444+
expected_headers: dict[str, str] = {}
438445
for key, value in self.headers.items():
439446
if not self.header_value_matcher(key, request.headers.get(key), value):
440447
request_headers[key] = request.headers.get(key)
@@ -467,7 +474,7 @@ class RequestHandlerBase(abc.ABC):
467474

468475
def respond_with_json(
469476
self,
470-
response_json,
477+
response_json: Any,
471478
status: int = 200,
472479
headers: Mapping[str, str] | None = None,
473480
content_type: str = "application/json",
@@ -578,7 +585,7 @@ def __repr__(self) -> str:
578585
return retval
579586

580587

581-
class RequestHandlerList(list):
588+
class RequestHandlerList(List[RequestHandler]):
582589
"""
583590
Represents a list of :py:class:`RequestHandler` objects.
584591
@@ -638,9 +645,9 @@ def __init__(
638645
"""
639646
self.host = host
640647
self.port = port
641-
self.server = None
642-
self.server_thread = None
643-
self.assertions: list[str] = []
648+
self.server: BaseWSGIServer | None = None
649+
self.server_thread: threading.Thread | None = None
650+
self.assertions: list[str | AssertionError] = []
644651
self.handler_errors: list[Exception] = []
645652
self.log: list[tuple[Request, Response]] = []
646653
self.ssl_context = ssl_context
@@ -727,7 +734,7 @@ def thread_target(self):
727734
728735
This should not be called directly, but can be overridden to tailor it to your needs.
729736
"""
730-
737+
assert self.server is not None
731738
self.server.serve_forever()
732739

733740
def is_running(self) -> bool:
@@ -736,7 +743,7 @@ def is_running(self) -> bool:
736743
"""
737744
return bool(self.server)
738745

739-
def start(self):
746+
def start(self) -> None:
740747
"""
741748
Start the server in a thread.
742749
@@ -755,9 +762,16 @@ def start(self):
755762
if self.is_running():
756763
raise HTTPServerError("Server is already running")
757764

765+
app = Request.application(self.application)
766+
758767
self.server = make_server(
759-
self.host, self.port, self.application, ssl_context=self.ssl_context, threaded=self.threaded
768+
self.host,
769+
self.port,
770+
app,
771+
ssl_context=self.ssl_context,
772+
threaded=self.threaded,
760773
)
774+
761775
self.port = self.server.port # Update port (needed if `port` was set to 0)
762776
self.server_thread = threading.Thread(target=self.thread_target)
763777
self.server_thread.start()
@@ -772,14 +786,16 @@ def stop(self):
772786
Only a running server can be stopped. If the sever is not running, :py:class`HTTPServerError`
773787
will be raised.
774788
"""
789+
assert self.server is not None
790+
assert self.server_thread is not None
775791
if not self.is_running():
776792
raise HTTPServerError("Server is not running")
777793
self.server.shutdown()
778794
self.server_thread.join()
779795
self.server = None
780796
self.server_thread = None
781797

782-
def add_assertion(self, obj):
798+
def add_assertion(self, obj: str | AssertionError):
783799
"""
784800
Add a new assertion
785801
@@ -848,8 +864,7 @@ def dispatch(self, request: Request) -> Response:
848864
:return: the response object what the handler responded, or a response which contains the error
849865
"""
850866

851-
@Request.application # type: ignore
852-
def application(self, request: Request):
867+
def application(self, request: Request) -> Response:
853868
"""
854869
Entry point of werkzeug.
855870
@@ -875,7 +890,12 @@ def __enter__(self):
875890
self.start()
876891
return self
877892

878-
def __exit__(self, *args, **kwargs):
893+
def __exit__(
894+
self,
895+
exc_type: type[BaseException] | None,
896+
exc_value: BaseException | None,
897+
traceback: TracebackType | None,
898+
):
879899
"""
880900
Provide the context API
881901
@@ -886,7 +906,7 @@ def __exit__(self, *args, **kwargs):
886906
self.stop()
887907

888908
@staticmethod
889-
def format_host(host):
909+
def format_host(host: str):
890910
"""
891911
Formats a hostname so it can be used in a URL.
892912
Notably, this adds brackets around IPV6 addresses when
@@ -929,8 +949,8 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute
929949

930950
def __init__(
931951
self,
932-
host=DEFAULT_LISTEN_HOST,
933-
port=DEFAULT_LISTEN_PORT,
952+
host: str = DEFAULT_LISTEN_HOST,
953+
port: int = DEFAULT_LISTEN_PORT,
934954
ssl_context: SSLContext | None = None,
935955
default_waiting_settings: WaitingSettings | None = None,
936956
*,
@@ -979,7 +999,7 @@ def expect_request(
979999
data: str | bytes | None = None,
9801000
data_encoding: str = "utf-8",
9811001
headers: Mapping[str, str] | None = None,
982-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1002+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
9831003
header_value_matcher: HVMATCHER_T | None = None,
9841004
handler_type: HandlerType = HandlerType.PERMANENT,
9851005
json: Any = UNDEFINED,
@@ -1062,7 +1082,7 @@ def expect_oneshot_request(
10621082
data: str | bytes | None = None,
10631083
data_encoding: str = "utf-8",
10641084
headers: Mapping[str, str] | None = None,
1065-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1085+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
10661086
header_value_matcher: HVMATCHER_T | None = None,
10671087
json: Any = UNDEFINED,
10681088
) -> RequestHandler:
@@ -1117,7 +1137,7 @@ def expect_ordered_request(
11171137
data: str | bytes | None = None,
11181138
data_encoding: str = "utf-8",
11191139
headers: Mapping[str, str] | None = None,
1120-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1140+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
11211141
header_value_matcher: HVMATCHER_T | None = None,
11221142
json: Any = UNDEFINED,
11231143
) -> RequestHandler:
@@ -1175,13 +1195,13 @@ def format_matchers(self) -> str:
11751195
This method is primarily used when reporting errors.
11761196
"""
11771197

1178-
def format_handlers(handlers):
1198+
def format_handlers(handlers: list[RequestHandler]):
11791199
if handlers:
11801200
return [" {!r}".format(handler.matcher) for handler in handlers]
11811201
else:
11821202
return [" none"]
11831203

1184-
lines = []
1204+
lines: list[str] = []
11851205
lines.append("Ordered matchers:")
11861206
lines.extend(format_handlers(self.ordered_handlers))
11871207
lines.append("")

tests/test_release.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
NAME = "pytest-httpserver"
2828
NAME_UNDERSCORE = NAME.replace("-", "_")
29-
PY_MAX_VERSION = (3, 12)
29+
PY_MAX_VERSION = (3, 13)
3030

3131

3232
@pytest.fixture(scope="session")

0 commit comments

Comments
 (0)