Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ jobs:
name: Test with python ${{ matrix.python-version }} / ${{ matrix.os-version }}
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13-dev"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
os-version: ["ubuntu-latest", "windows-latest"]
exclude:
- os-version: windows-latest
include:
- os-version: windows-latest
python-version: 3.12
python-version: 3.13

runs-on: ${{ matrix.os-version }}

Expand Down Expand Up @@ -93,7 +93,7 @@ jobs:
uses: ./.github/actions/setup
with:
type: doc
python-version: "3.12"
python-version: "3.13"
poetry-version: ${{ env.POETRY_VERSION }}

- name: Test
Expand Down
1,005 changes: 495 additions & 510 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ include = [
]

[tool.poetry.dependencies]
python = ">=3.8"
python = ">=3.9"
Werkzeug = ">= 2.0.0"


Expand Down Expand Up @@ -137,5 +137,5 @@ lint.ignore = [
"UP032",
]
line-length = 120
target-version = "py38"
target-version = "py39"
exclude = ["doc", "example*.py", "tests/examples/*.py"]
4 changes: 2 additions & 2 deletions pytest_httpserver/blocking_httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from queue import Queue
from typing import TYPE_CHECKING
from typing import Any
from typing import Mapping
from typing import Pattern

from pytest_httpserver.httpserver import METHOD_ALL
from pytest_httpserver.httpserver import UNDEFINED
Expand All @@ -16,6 +14,8 @@
from pytest_httpserver.httpserver import URIPattern

if TYPE_CHECKING:
from collections.abc import Mapping
from re import Pattern
from ssl import SSLContext

from werkzeug import Request
Expand Down
97 changes: 57 additions & 40 deletions pytest_httpserver/httpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,19 @@
import time
import urllib.parse
from collections import defaultdict
from collections.abc import Iterable
from collections.abc import Mapping
from collections.abc import MutableMapping
from contextlib import contextmanager
from contextlib import suppress
from copy import copy
from enum import Enum
from re import Pattern
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import ClassVar
from typing import Iterable
from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import Pattern
from typing import Tuple
from typing import Union

import werkzeug.http
Expand All @@ -34,13 +33,16 @@

if TYPE_CHECKING:
from ssl import SSLContext
from types import TracebackType

from werkzeug.serving import BaseWSGIServer

URI_DEFAULT = ""
METHOD_ALL = "__ALL"

HEADERS_T = Union[
Mapping[str, Union[str, Iterable[str]]],
Iterable[Tuple[str, str]],
Iterable[tuple[str, str]],
]

HVMATCHER_T = Callable[[str, Optional[str], str], bool]
Expand Down Expand Up @@ -113,11 +115,13 @@ def complete(self, result: bool): # noqa: FBT001

@property
def result(self) -> bool:
return self._result
return bool(self._result)

@property
def elapsed_time(self) -> float:
"""Elapsed time in seconds"""
if self._stop is None:
raise TypeError("unsupported operand type(s) for -: 'NoneType' and 'float'")
return self._stop - self._start


Expand All @@ -139,7 +143,7 @@ def authorization_header_value_matcher(actual: str | None, expected: str) -> boo
func = getattr(Authorization, "from_header", None)
if func is None: # Werkzeug < 2.3.0
func = werkzeug.http.parse_authorization_header # type: ignore[attr-defined]
return func(actual) == func(expected)
return func(actual) == func(expected) # type: ignore

@staticmethod
def default_header_value_matcher(actual: str | None, expected: str) -> bool:
Expand Down Expand Up @@ -174,7 +178,7 @@ def match(self, request_query_string: bytes) -> bool:
return values[0] == values[1]

@abc.abstractmethod
def get_comparing_values(self, request_query_string: bytes) -> tuple:
def get_comparing_values(self, request_query_string: bytes) -> tuple[Any, Any]:
pass


Expand All @@ -195,10 +199,10 @@ def __init__(self, query_string: bytes | str):

self.query_string = query_string

def get_comparing_values(self, request_query_string: bytes) -> tuple:
def get_comparing_values(self, request_query_string: bytes) -> tuple[bytes, bytes]:
if isinstance(self.query_string, str):
query_string = self.query_string.encode()
elif isinstance(self.query_string, bytes):
elif isinstance(self.query_string, bytes): # type: ignore
query_string = self.query_string
else:
raise TypeError("query_string must be a string, or a bytes-like object")
Expand All @@ -211,7 +215,7 @@ class MappingQueryMatcher(QueryMatcher):
Matches a query string to a dictionary or MultiDict specified
"""

def __init__(self, query_dict: Mapping | MultiDict):
def __init__(self, query_dict: Mapping[str, str] | MultiDict[str, str]):
"""
:param query_dict: if dictionary (Mapping) is specified, it will be used as a
key-value mapping where both key and value should be string. If there are multiple
Expand All @@ -221,7 +225,7 @@ def __init__(self, query_dict: Mapping | MultiDict):
"""
self.query_dict = query_dict

def get_comparing_values(self, request_query_string: bytes) -> tuple:
def get_comparing_values(self, request_query_string: bytes) -> tuple[Mapping[str, str], Mapping[str, str]]:
query = MultiDict(urllib.parse.parse_qsl(request_query_string.decode("utf-8")))
if isinstance(self.query_dict, MultiDict):
return (query, self.query_dict)
Expand All @@ -241,14 +245,14 @@ def __init__(self, result: bool): # noqa: FBT001
"""
self.result = result

def get_comparing_values(self, request_query_string): # noqa: ARG002
def get_comparing_values(self, request_query_string: bytes): # noqa: ARG002
if self.result:
return (True, True)
else:
return (True, False)


def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping) -> QueryMatcher:
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping[str, str]) -> QueryMatcher:
if isinstance(query_string, QueryMatcher):
return query_string

Expand Down Expand Up @@ -312,7 +316,7 @@ def __init__(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
json: Any = UNDEFINED,
):
Expand Down Expand Up @@ -410,7 +414,7 @@ def match_json(self, request: Request) -> bool:

return json_received == self.json

def difference(self, request: Request) -> list[tuple]:
def difference(self, request: Request) -> list[tuple[str, str, str | URIPattern]]:
"""
Calculates the difference between the matcher and the request.

Expand All @@ -422,7 +426,7 @@ def difference(self, request: Request) -> list[tuple]:
matches the fields set in the matcher object.
"""

retval: list[tuple] = []
retval: list[tuple[str, Any, Any]] = []

if not self.match_uri(request):
retval.append(("uri", request.path, self.uri))
Expand All @@ -433,8 +437,8 @@ def difference(self, request: Request) -> list[tuple]:
if not self.query_matcher.match(request.query_string):
retval.append(("query_string", request.query_string, self.query_string))

request_headers = {}
expected_headers = {}
request_headers: dict[str, str | None] = {}
expected_headers: dict[str, str] = {}
for key, value in self.headers.items():
if not self.header_value_matcher(key, request.headers.get(key), value):
request_headers[key] = request.headers.get(key)
Expand Down Expand Up @@ -467,7 +471,7 @@ class RequestHandlerBase(abc.ABC):

def respond_with_json(
self,
response_json,
response_json: Any,
status: int = 200,
headers: Mapping[str, str] | None = None,
content_type: str = "application/json",
Expand Down Expand Up @@ -578,7 +582,7 @@ def __repr__(self) -> str:
return retval


class RequestHandlerList(list):
class RequestHandlerList(list[RequestHandler]):
"""
Represents a list of :py:class:`RequestHandler` objects.

Expand Down Expand Up @@ -638,9 +642,9 @@ def __init__(
"""
self.host = host
self.port = port
self.server = None
self.server_thread = None
self.assertions: list[str] = []
self.server: BaseWSGIServer | None = None
self.server_thread: threading.Thread | None = None
self.assertions: list[str | AssertionError] = []
self.handler_errors: list[Exception] = []
self.log: list[tuple[Request, Response]] = []
self.ssl_context = ssl_context
Expand Down Expand Up @@ -727,7 +731,7 @@ def thread_target(self):

This should not be called directly, but can be overridden to tailor it to your needs.
"""

assert self.server is not None
self.server.serve_forever()

def is_running(self) -> bool:
Expand All @@ -736,7 +740,7 @@ def is_running(self) -> bool:
"""
return bool(self.server)

def start(self):
def start(self) -> None:
"""
Start the server in a thread.

Expand All @@ -755,9 +759,16 @@ def start(self):
if self.is_running():
raise HTTPServerError("Server is already running")

app = Request.application(self.application)

self.server = make_server(
self.host, self.port, self.application, ssl_context=self.ssl_context, threaded=self.threaded
self.host,
self.port,
app,
ssl_context=self.ssl_context,
threaded=self.threaded,
)

self.port = self.server.port # Update port (needed if `port` was set to 0)
self.server_thread = threading.Thread(target=self.thread_target)
self.server_thread.start()
Expand All @@ -772,14 +783,16 @@ def stop(self):
Only a running server can be stopped. If the sever is not running, :py:class`HTTPServerError`
will be raised.
"""
assert self.server is not None
assert self.server_thread is not None
if not self.is_running():
raise HTTPServerError("Server is not running")
self.server.shutdown()
self.server_thread.join()
self.server = None
self.server_thread = None

def add_assertion(self, obj):
def add_assertion(self, obj: str | AssertionError):
"""
Add a new assertion

Expand Down Expand Up @@ -848,8 +861,7 @@ def dispatch(self, request: Request) -> Response:
:return: the response object what the handler responded, or a response which contains the error
"""

@Request.application # type: ignore
def application(self, request: Request):
def application(self, request: Request) -> Response:
"""
Entry point of werkzeug.

Expand All @@ -875,7 +887,12 @@ def __enter__(self):
self.start()
return self

def __exit__(self, *args, **kwargs):
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
):
"""
Provide the context API

Expand All @@ -886,7 +903,7 @@ def __exit__(self, *args, **kwargs):
self.stop()

@staticmethod
def format_host(host):
def format_host(host: str):
"""
Formats a hostname so it can be used in a URL.
Notably, this adds brackets around IPV6 addresses when
Expand Down Expand Up @@ -929,8 +946,8 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute

def __init__(
self,
host=DEFAULT_LISTEN_HOST,
port=DEFAULT_LISTEN_PORT,
host: str = DEFAULT_LISTEN_HOST,
port: int = DEFAULT_LISTEN_PORT,
ssl_context: SSLContext | None = None,
default_waiting_settings: WaitingSettings | None = None,
*,
Expand Down Expand Up @@ -995,7 +1012,7 @@ def expect_request(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
handler_type: HandlerType = HandlerType.PERMANENT,
json: Any = UNDEFINED,
Expand Down Expand Up @@ -1078,7 +1095,7 @@ def expect_oneshot_request(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
json: Any = UNDEFINED,
) -> RequestHandler:
Expand Down Expand Up @@ -1133,7 +1150,7 @@ def expect_ordered_request(
data: str | bytes | None = None,
data_encoding: str = "utf-8",
headers: Mapping[str, str] | None = None,
query_string: None | QueryMatcher | str | bytes | Mapping = None,
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
header_value_matcher: HVMATCHER_T | None = None,
json: Any = UNDEFINED,
) -> RequestHandler:
Expand Down Expand Up @@ -1191,13 +1208,13 @@ def format_matchers(self) -> str:
This method is primarily used when reporting errors.
"""

def format_handlers(handlers):
def format_handlers(handlers: list[RequestHandler]):
if handlers:
return [" {!r}".format(handler.matcher) for handler in handlers]
else:
return [" none"]

lines = []
lines: list[str] = []
lines.append("Ordered matchers:")
lines.extend(format_handlers(self.ordered_handlers))
lines.append("")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
deprecations:
- |
Python versions earlier than 3.9 have been deprecated in order to make the
code more type safe. Python 3.8 has reached EOL on 2024-10-07.
Loading