1818from typing import Callable
1919from typing import ClassVar
2020from typing import Iterable
21+ from typing import List
2122from typing import Mapping
2223from typing import MutableMapping
2324from typing import Optional
3435
3536if TYPE_CHECKING :
3637 from ssl import SSLContext
38+ from types import TracebackType
39+
40+ from werkzeug .serving import BaseWSGIServer
3741
3842URI_DEFAULT = ""
3943METHOD_ALL = "__ALL"
@@ -113,11 +117,13 @@ def complete(self, result: bool): # noqa: FBT001
113117
114118 @property
115119 def result (self ) -> bool :
116- return self ._result
120+ return bool ( self ._result )
117121
118122 @property
119123 def elapsed_time (self ) -> float :
120124 """Elapsed time in seconds"""
125+ if self ._stop is None :
126+ raise TypeError ("unsupported operand type(s) for -: 'NoneType' and 'float'" )
121127 return self ._stop - self ._start
122128
123129
@@ -139,7 +145,8 @@ def authorization_header_value_matcher(actual: str | None, expected: str) -> boo
139145 func = getattr (Authorization , "from_header" , None )
140146 if func is None : # Werkzeug < 2.3.0
141147 func = werkzeug .http .parse_authorization_header # type: ignore[attr-defined]
142- return func (actual ) == func (expected )
148+
149+ return func (actual ) == func (expected ) # type: ignore
143150
144151 @staticmethod
145152 def default_header_value_matcher (actual : str | None , expected : str ) -> bool :
@@ -174,7 +181,7 @@ def match(self, request_query_string: bytes) -> bool:
174181 return values [0 ] == values [1 ]
175182
176183 @abc .abstractmethod
177- def get_comparing_values (self , request_query_string : bytes ) -> tuple :
184+ def get_comparing_values (self , request_query_string : bytes ) -> tuple [ Any , Any ] :
178185 pass
179186
180187
@@ -195,10 +202,10 @@ def __init__(self, query_string: bytes | str):
195202
196203 self .query_string = query_string
197204
198- def get_comparing_values (self , request_query_string : bytes ) -> tuple :
205+ def get_comparing_values (self , request_query_string : bytes ) -> tuple [ bytes , bytes ] :
199206 if isinstance (self .query_string , str ):
200207 query_string = self .query_string .encode ()
201- elif isinstance (self .query_string , bytes ):
208+ elif isinstance (self .query_string , bytes ): # type: ignore
202209 query_string = self .query_string
203210 else :
204211 raise TypeError ("query_string must be a string, or a bytes-like object" )
@@ -211,7 +218,7 @@ class MappingQueryMatcher(QueryMatcher):
211218 Matches a query string to a dictionary or MultiDict specified
212219 """
213220
214- def __init__ (self , query_dict : Mapping | MultiDict ):
221+ def __init__ (self , query_dict : Mapping [ str , str ] | MultiDict [ str , str ] ):
215222 """
216223 :param query_dict: if dictionary (Mapping) is specified, it will be used as a
217224 key-value mapping where both key and value should be string. If there are multiple
@@ -221,7 +228,7 @@ def __init__(self, query_dict: Mapping | MultiDict):
221228 """
222229 self .query_dict = query_dict
223230
224- def get_comparing_values (self , request_query_string : bytes ) -> tuple :
231+ def get_comparing_values (self , request_query_string : bytes ) -> tuple [ Mapping [ str , str ], Mapping [ str , str ]] :
225232 query = MultiDict (urllib .parse .parse_qsl (request_query_string .decode ("utf-8" )))
226233 if isinstance (self .query_dict , MultiDict ):
227234 return (query , self .query_dict )
@@ -241,14 +248,14 @@ def __init__(self, result: bool): # noqa: FBT001
241248 """
242249 self .result = result
243250
244- def get_comparing_values (self , request_query_string ): # noqa: ARG002
251+ def get_comparing_values (self , request_query_string : bytes ): # noqa: ARG002
245252 if self .result :
246253 return (True , True )
247254 else :
248255 return (True , False )
249256
250257
251- def _create_query_matcher (query_string : None | QueryMatcher | str | bytes | Mapping ) -> QueryMatcher :
258+ def _create_query_matcher (query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] ) -> QueryMatcher :
252259 if isinstance (query_string , QueryMatcher ):
253260 return query_string
254261
@@ -312,7 +319,7 @@ def __init__(
312319 data : str | bytes | None = None ,
313320 data_encoding : str = "utf-8" ,
314321 headers : Mapping [str , str ] | None = None ,
315- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
322+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
316323 header_value_matcher : HVMATCHER_T | None = None ,
317324 json : Any = UNDEFINED ,
318325 ):
@@ -410,7 +417,7 @@ def match_json(self, request: Request) -> bool:
410417
411418 return json_received == self .json
412419
413- def difference (self , request : Request ) -> list [tuple ]:
420+ def difference (self , request : Request ) -> list [tuple [ str , str , str | URIPattern ] ]:
414421 """
415422 Calculates the difference between the matcher and the request.
416423
@@ -422,7 +429,7 @@ def difference(self, request: Request) -> list[tuple]:
422429 matches the fields set in the matcher object.
423430 """
424431
425- retval : list [tuple ] = []
432+ retval : list [tuple [ str , Any , Any ] ] = []
426433
427434 if not self .match_uri (request ):
428435 retval .append (("uri" , request .path , self .uri ))
@@ -433,8 +440,8 @@ def difference(self, request: Request) -> list[tuple]:
433440 if not self .query_matcher .match (request .query_string ):
434441 retval .append (("query_string" , request .query_string , self .query_string ))
435442
436- request_headers = {}
437- expected_headers = {}
443+ request_headers : dict [ str , str | None ] = {}
444+ expected_headers : dict [ str , str ] = {}
438445 for key , value in self .headers .items ():
439446 if not self .header_value_matcher (key , request .headers .get (key ), value ):
440447 request_headers [key ] = request .headers .get (key )
@@ -467,7 +474,7 @@ class RequestHandlerBase(abc.ABC):
467474
468475 def respond_with_json (
469476 self ,
470- response_json ,
477+ response_json : Any ,
471478 status : int = 200 ,
472479 headers : Mapping [str , str ] | None = None ,
473480 content_type : str = "application/json" ,
@@ -578,7 +585,7 @@ def __repr__(self) -> str:
578585 return retval
579586
580587
581- class RequestHandlerList (list ):
588+ class RequestHandlerList (List [ RequestHandler ] ):
582589 """
583590 Represents a list of :py:class:`RequestHandler` objects.
584591
@@ -638,9 +645,9 @@ def __init__(
638645 """
639646 self .host = host
640647 self .port = port
641- self .server = None
642- self .server_thread = None
643- self .assertions : list [str ] = []
648+ self .server : BaseWSGIServer | None = None
649+ self .server_thread : threading . Thread | None = None
650+ self .assertions : list [str | AssertionError ] = []
644651 self .handler_errors : list [Exception ] = []
645652 self .log : list [tuple [Request , Response ]] = []
646653 self .ssl_context = ssl_context
@@ -727,7 +734,7 @@ def thread_target(self):
727734
728735 This should not be called directly, but can be overridden to tailor it to your needs.
729736 """
730-
737+ assert self . server is not None
731738 self .server .serve_forever ()
732739
733740 def is_running (self ) -> bool :
@@ -736,7 +743,7 @@ def is_running(self) -> bool:
736743 """
737744 return bool (self .server )
738745
739- def start (self ):
746+ def start (self ) -> None :
740747 """
741748 Start the server in a thread.
742749
@@ -755,9 +762,16 @@ def start(self):
755762 if self .is_running ():
756763 raise HTTPServerError ("Server is already running" )
757764
765+ app = Request .application (self .application )
766+
758767 self .server = make_server (
759- self .host , self .port , self .application , ssl_context = self .ssl_context , threaded = self .threaded
768+ self .host ,
769+ self .port ,
770+ app ,
771+ ssl_context = self .ssl_context ,
772+ threaded = self .threaded ,
760773 )
774+
761775 self .port = self .server .port # Update port (needed if `port` was set to 0)
762776 self .server_thread = threading .Thread (target = self .thread_target )
763777 self .server_thread .start ()
@@ -772,14 +786,16 @@ def stop(self):
772786 Only a running server can be stopped. If the sever is not running, :py:class`HTTPServerError`
773787 will be raised.
774788 """
789+ assert self .server is not None
790+ assert self .server_thread is not None
775791 if not self .is_running ():
776792 raise HTTPServerError ("Server is not running" )
777793 self .server .shutdown ()
778794 self .server_thread .join ()
779795 self .server = None
780796 self .server_thread = None
781797
782- def add_assertion (self , obj ):
798+ def add_assertion (self , obj : str | AssertionError ):
783799 """
784800 Add a new assertion
785801
@@ -848,8 +864,7 @@ def dispatch(self, request: Request) -> Response:
848864 :return: the response object what the handler responded, or a response which contains the error
849865 """
850866
851- @Request .application # type: ignore
852- def application (self , request : Request ):
867+ def application (self , request : Request ) -> Response :
853868 """
854869 Entry point of werkzeug.
855870
@@ -875,7 +890,12 @@ def __enter__(self):
875890 self .start ()
876891 return self
877892
878- def __exit__ (self , * args , ** kwargs ):
893+ def __exit__ (
894+ self ,
895+ exc_type : type [BaseException ] | None ,
896+ exc_value : BaseException | None ,
897+ traceback : TracebackType | None ,
898+ ):
879899 """
880900 Provide the context API
881901
@@ -886,7 +906,7 @@ def __exit__(self, *args, **kwargs):
886906 self .stop ()
887907
888908 @staticmethod
889- def format_host (host ):
909+ def format_host (host : str ):
890910 """
891911 Formats a hostname so it can be used in a URL.
892912 Notably, this adds brackets around IPV6 addresses when
@@ -929,8 +949,8 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute
929949
930950 def __init__ (
931951 self ,
932- host = DEFAULT_LISTEN_HOST ,
933- port = DEFAULT_LISTEN_PORT ,
952+ host : str = DEFAULT_LISTEN_HOST ,
953+ port : int = DEFAULT_LISTEN_PORT ,
934954 ssl_context : SSLContext | None = None ,
935955 default_waiting_settings : WaitingSettings | None = None ,
936956 * ,
@@ -979,7 +999,7 @@ def expect_request(
979999 data : str | bytes | None = None ,
9801000 data_encoding : str = "utf-8" ,
9811001 headers : Mapping [str , str ] | None = None ,
982- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
1002+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
9831003 header_value_matcher : HVMATCHER_T | None = None ,
9841004 handler_type : HandlerType = HandlerType .PERMANENT ,
9851005 json : Any = UNDEFINED ,
@@ -1062,7 +1082,7 @@ def expect_oneshot_request(
10621082 data : str | bytes | None = None ,
10631083 data_encoding : str = "utf-8" ,
10641084 headers : Mapping [str , str ] | None = None ,
1065- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
1085+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
10661086 header_value_matcher : HVMATCHER_T | None = None ,
10671087 json : Any = UNDEFINED ,
10681088 ) -> RequestHandler :
@@ -1117,7 +1137,7 @@ def expect_ordered_request(
11171137 data : str | bytes | None = None ,
11181138 data_encoding : str = "utf-8" ,
11191139 headers : Mapping [str , str ] | None = None ,
1120- query_string : None | QueryMatcher | str | bytes | Mapping = None ,
1140+ query_string : None | QueryMatcher | str | bytes | Mapping [ str , str ] = None ,
11211141 header_value_matcher : HVMATCHER_T | None = None ,
11221142 json : Any = UNDEFINED ,
11231143 ) -> RequestHandler :
@@ -1175,13 +1195,13 @@ def format_matchers(self) -> str:
11751195 This method is primarily used when reporting errors.
11761196 """
11771197
1178- def format_handlers (handlers ):
1198+ def format_handlers (handlers : list [ RequestHandler ] ):
11791199 if handlers :
11801200 return [" {!r}" .format (handler .matcher ) for handler in handlers ]
11811201 else :
11821202 return [" none" ]
11831203
1184- lines = []
1204+ lines : list [ str ] = []
11851205 lines .append ("Ordered matchers:" )
11861206 lines .extend (format_handlers (self .ordered_handlers ))
11871207 lines .append ("" )
0 commit comments