1818from typing import Callable
1919from typing import ClassVar
2020from typing import Iterable
21+ from typing import List
2122from typing import Mapping
2223from typing import MutableMapping
2324from typing import Optional
3435
3536if TYPE_CHECKING :
3637 from ssl import SSLContext
38+ from types import TracebackType
39+
40+ from werkzeug .serving import BaseWSGIServer
3741
3842URI_DEFAULT = ""
3943METHOD_ALL = "__ALL"
@@ -113,11 +117,13 @@ def complete(self, result: bool): # noqa: FBT001
113117
114118 @property
115119 def result (self ) -> bool :
116- return self ._result
120+ return bool ( self ._result )
117121
118122 @property
119123 def elapsed_time (self ) -> float :
120124 """Elapsed time in seconds"""
125+ if self ._stop is None :
126+ raise TypeError ("unsupported operand type(s) for -: 'NoneType' and 'float'" )
121127 return self ._stop - self ._start
122128
123129
@@ -139,7 +145,7 @@ def authorization_header_value_matcher(actual: str | None, expected: str) -> boo
139145 func = getattr (Authorization , "from_header" , None )
140146 if func is None : # Werkzeug < 2.3.0
141147 func = werkzeug .http .parse_authorization_header # type: ignore[attr-defined]
142- return func (actual ) == func (expected )
148+ return func (actual ) == func (expected ) # type: ignore
143149
144150 @staticmethod
145151 def default_header_value_matcher (actual : str | None , expected : str ) -> bool :
@@ -174,7 +180,7 @@ def match(self, request_query_string: bytes) -> bool:
174180 return values [0 ] == values [1 ]
175181
176182 @abc .abstractmethod
177- def get_comparing_values (self , request_query_string : bytes ) -> tuple :
183+ def get_comparing_values (self , request_query_string : bytes ) -> tuple [ Any , Any ] :
178184 pass
179185
180186
@@ -195,10 +201,10 @@ def __init__(self, query_string: bytes | str):
195201
196202 self .query_string = query_string
197203
198- def get_comparing_values (self , request_query_string : bytes ) -> tuple :
204+ def get_comparing_values (self , request_query_string : bytes ) -> tuple [ bytes , bytes ] :
199205 if isinstance (self .query_string , str ):
200206 query_string = self .query_string .encode ()
201- elif isinstance (self .query_string , bytes ):
207+ elif isinstance (self .query_string , bytes ): # type: ignore
202208 query_string = self .query_string
203209 else :
204210 raise TypeError ("query_string must be a string, or a bytes-like object" )
@@ -211,7 +217,7 @@ class MappingQueryMatcher(QueryMatcher):
211217 Matches a query string to a dictionary or MultiDict specified
212218 """
213219
214- def __init__ (self , query_dict : Mapping | MultiDict ):
220+ def __init__ (self , query_dict : Mapping [ str , str ] | MultiDict [ str , str ] ):
215221 """
216222 :param query_dict: if dictionary (Mapping) is specified, it will be used as a
217223 key-value mapping where both key and value should be string. If there are multiple
@@ -221,7 +227,7 @@ def __init__(self, query_dict: Mapping | MultiDict):
221227 """
222228 self .query_dict = query_dict
223229
224- def get_comparing_values (self , request_query_string : bytes ) -> tuple :
230+ def get_comparing_values (self , request_query_string : bytes ) -> tuple [ Mapping [ str , str ], Mapping [ str , str ]] :
225231 query = MultiDict (urllib .parse .parse_qsl (request_query_string .decode ("utf-8" )))
226232 if isinstance (self .query_dict , MultiDict ):
227233 return (query , self .query_dict )
@@ -241,14 +247,14 @@ def __init__(self, result: bool): # noqa: FBT001
241247 """
242248 self .result = result
243249
244- def get_comparing_values (self , request_query_string ): # noqa: ARG002
250+ def get_comparing_values (self , request_query_string : bytes ): # noqa: ARG002
245251 if self .result :
246252 return (True , True )
247253 else :
248254 return (True , False )
249255
250256
251- def _create_query_matcher (query_string : None | QueryMatcher | str | bytes | Mapping ) -> QueryMatcher :
257+ def _create_query_matcher (query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] ) -> QueryMatcher :
252258 if isinstance (query_string , QueryMatcher ):
253259 return query_string
254260
@@ -312,7 +318,7 @@ def __init__(
312318 data : str | bytes | None = None ,
313319 data_encoding : str = "utf-8" ,
314320 headers : Mapping [str , str ] | None = None ,
315- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
321+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
316322 header_value_matcher : HVMATCHER_T | None = None ,
317323 json : Any = UNDEFINED ,
318324 ):
@@ -410,7 +416,7 @@ def match_json(self, request: Request) -> bool:
410416
411417 return json_received == self .json
412418
413- def difference (self , request : Request ) -> list [tuple ]:
419+ def difference (self , request : Request ) -> list [tuple [ str , str , str | URIPattern ] ]:
414420 """
415421 Calculates the difference between the matcher and the request.
416422
@@ -422,7 +428,7 @@ def difference(self, request: Request) -> list[tuple]:
422428 matches the fields set in the matcher object.
423429 """
424430
425- retval : list [tuple ] = []
431+ retval : list [tuple [ str , Any , Any ] ] = []
426432
427433 if not self .match_uri (request ):
428434 retval .append (("uri" , request .path , self .uri ))
@@ -433,8 +439,8 @@ def difference(self, request: Request) -> list[tuple]:
433439 if not self .query_matcher .match (request .query_string ):
434440 retval .append (("query_string" , request .query_string , self .query_string ))
435441
436- request_headers = {}
437- expected_headers = {}
442+ request_headers : dict [ str , str | None ] = {}
443+ expected_headers : dict [ str , str ] = {}
438444 for key , value in self .headers .items ():
439445 if not self .header_value_matcher (key , request .headers .get (key ), value ):
440446 request_headers [key ] = request .headers .get (key )
@@ -467,7 +473,7 @@ class RequestHandlerBase(abc.ABC):
467473
468474 def respond_with_json (
469475 self ,
470- response_json ,
476+ response_json : Any ,
471477 status : int = 200 ,
472478 headers : Mapping [str , str ] | None = None ,
473479 content_type : str = "application/json" ,
@@ -578,7 +584,7 @@ def __repr__(self) -> str:
578584 return retval
579585
580586
581- class RequestHandlerList (list ):
587+ class RequestHandlerList (List [ RequestHandler ] ):
582588 """
583589 Represents a list of :py:class:`RequestHandler` objects.
584590
@@ -638,9 +644,9 @@ def __init__(
638644 """
639645 self .host = host
640646 self .port = port
641- self .server = None
642- self .server_thread = None
643- self .assertions : list [str ] = []
647+ self .server : BaseWSGIServer | None = None
648+ self .server_thread : threading . Thread | None = None
649+ self .assertions : list [str | AssertionError ] = []
644650 self .handler_errors : list [Exception ] = []
645651 self .log : list [tuple [Request , Response ]] = []
646652 self .ssl_context = ssl_context
@@ -727,7 +733,7 @@ def thread_target(self):
727733
728734 This should not be called directly, but can be overridden to tailor it to your needs.
729735 """
730-
736+ assert self . server is not None
731737 self .server .serve_forever ()
732738
733739 def is_running (self ) -> bool :
@@ -736,7 +742,7 @@ def is_running(self) -> bool:
736742 """
737743 return bool (self .server )
738744
739- def start (self ):
745+ def start (self ) -> None :
740746 """
741747 Start the server in a thread.
742748
@@ -755,9 +761,16 @@ def start(self):
755761 if self .is_running ():
756762 raise HTTPServerError ("Server is already running" )
757763
764+ app = Request .application (self .application )
765+
758766 self .server = make_server (
759- self .host , self .port , self .application , ssl_context = self .ssl_context , threaded = self .threaded
767+ self .host ,
768+ self .port ,
769+ app ,
770+ ssl_context = self .ssl_context ,
771+ threaded = self .threaded ,
760772 )
773+
761774 self .port = self .server .port # Update port (needed if `port` was set to 0)
762775 self .server_thread = threading .Thread (target = self .thread_target )
763776 self .server_thread .start ()
@@ -772,14 +785,16 @@ def stop(self):
772785 Only a running server can be stopped. If the sever is not running, :py:class`HTTPServerError`
773786 will be raised.
774787 """
788+ assert self .server is not None
789+ assert self .server_thread is not None
775790 if not self .is_running ():
776791 raise HTTPServerError ("Server is not running" )
777792 self .server .shutdown ()
778793 self .server_thread .join ()
779794 self .server = None
780795 self .server_thread = None
781796
782- def add_assertion (self , obj ):
797+ def add_assertion (self , obj : str | AssertionError ):
783798 """
784799 Add a new assertion
785800
@@ -848,8 +863,7 @@ def dispatch(self, request: Request) -> Response:
848863 :return: the response object what the handler responded, or a response which contains the error
849864 """
850865
851- @Request .application # type: ignore
852- def application (self , request : Request ):
866+ def application (self , request : Request ) -> Response :
853867 """
854868 Entry point of werkzeug.
855869
@@ -875,7 +889,12 @@ def __enter__(self):
875889 self .start ()
876890 return self
877891
878- def __exit__ (self , * args , ** kwargs ):
892+ def __exit__ (
893+ self ,
894+ exc_type : type [BaseException ] | None ,
895+ exc_value : BaseException | None ,
896+ traceback : TracebackType | None ,
897+ ):
879898 """
880899 Provide the context API
881900
@@ -886,7 +905,7 @@ def __exit__(self, *args, **kwargs):
886905 self .stop ()
887906
888907 @staticmethod
889- def format_host (host ):
908+ def format_host (host : str ):
890909 """
891910 Formats a hostname so it can be used in a URL.
892911 Notably, this adds brackets around IPV6 addresses when
@@ -929,8 +948,8 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute
929948
930949 def __init__ (
931950 self ,
932- host = DEFAULT_LISTEN_HOST ,
933- port = DEFAULT_LISTEN_PORT ,
951+ host : str = DEFAULT_LISTEN_HOST ,
952+ port : int = DEFAULT_LISTEN_PORT ,
934953 ssl_context : SSLContext | None = None ,
935954 default_waiting_settings : WaitingSettings | None = None ,
936955 * ,
@@ -979,7 +998,7 @@ def expect_request(
979998 data : str | bytes | None = None ,
980999 data_encoding : str = "utf-8" ,
9811000 headers : Mapping [str , str ] | None = None ,
982- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
1001+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
9831002 header_value_matcher : HVMATCHER_T | None = None ,
9841003 handler_type : HandlerType = HandlerType .PERMANENT ,
9851004 json : Any = UNDEFINED ,
@@ -1062,7 +1081,7 @@ def expect_oneshot_request(
10621081 data : str | bytes | None = None ,
10631082 data_encoding : str = "utf-8" ,
10641083 headers : Mapping [str , str ] | None = None ,
1065- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
1084+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
10661085 header_value_matcher : HVMATCHER_T | None = None ,
10671086 json : Any = UNDEFINED ,
10681087 ) -> RequestHandler :
@@ -1117,7 +1136,7 @@ def expect_ordered_request(
11171136 data : str | bytes | None = None ,
11181137 data_encoding : str = "utf-8" ,
11191138 headers : Mapping [str , str ] | None = None ,
1120- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
1139+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
11211140 header_value_matcher : HVMATCHER_T | None = None ,
11221141 json : Any = UNDEFINED ,
11231142 ) -> RequestHandler :
@@ -1175,13 +1194,13 @@ def format_matchers(self) -> str:
11751194 This method is primarily used when reporting errors.
11761195 """
11771196
1178- def format_handlers (handlers ):
1197+ def format_handlers (handlers : list [ RequestHandler ] ):
11791198 if handlers :
11801199 return [" {!r}" .format (handler .matcher ) for handler in handlers ]
11811200 else :
11821201 return [" none" ]
11831202
1184- lines = []
1203+ lines : list [ str ] = []
11851204 lines .append ("Ordered matchers:" )
11861205 lines .extend (format_handlers (self .ordered_handlers ))
11871206 lines .append ("" )
0 commit comments