22import time
33from functools import cached_property
44from http .server import BaseHTTPRequestHandler
5+ from typing import Callable , Optional
56from urllib .parse import parse_qs , unquote , urlsplit
67
78from h11 import SERVER , Connection , Data
@@ -82,9 +83,7 @@ def __init__(self, body="", status=200, headers=None):
8283 self .status = status
8384
8485 self .set_base_headers ()
85-
86- if headers is not None :
87- self .set_extra_headers (headers )
86+ self .set_extra_headers (headers )
8887
8988 self .data = self .get_protocol_data () + self .body
9089
@@ -142,9 +141,19 @@ class Entry(MocketEntry):
142141 request_cls = Request
143142 response_cls = Response
144143
145- default_config = {"match_querystring" : True }
144+ default_config = {"match_querystring" : True , "can_handle_fun" : None }
145+ _can_handle_fun : Optional [Callable ] = None
146+
147+ def __init__ (
148+ self ,
149+ uri ,
150+ method ,
151+ responses ,
152+ match_querystring : bool = True ,
153+ can_handle_fun : Optional [Callable ] = None ,
154+ ):
155+ self ._can_handle_fun = can_handle_fun if can_handle_fun else self ._can_handle
146156
147- def __init__ (self , uri , method , responses , match_querystring : bool = True ):
148157 uri = urlsplit (uri )
149158
150159 port = uri .port
@@ -177,6 +186,18 @@ def collect(self, data):
177186
178187 return consume_response
179188
189+ def _can_handle (self , path : str , qs_dict : dict ) -> bool :
190+ """
191+ The default can_handle function, which checks if the path match,
192+ and if match_querystring is True, also checks if the querystring matches.
193+ """
194+ can_handle = path == self .path
195+ if self ._match_querystring :
196+ can_handle = can_handle and qs_dict == parse_qs (
197+ self .query , keep_blank_values = True
198+ )
199+ return can_handle
200+
180201 def can_handle (self , data ):
181202 r"""
182203 >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b'<html/>'),))
@@ -192,13 +213,12 @@ def can_handle(self, data):
192213 except ValueError :
193214 return self is getattr (Mocket , "_last_entry" , None )
194215
195- uri = urlsplit (path )
196- can_handle = uri .path == self .path and method == self .method
197- if self ._match_querystring :
198- kw = dict (keep_blank_values = True )
199- can_handle = can_handle and parse_qs (uri .query , ** kw ) == parse_qs (
200- self .query , ** kw
201- )
216+ _request = urlsplit (path )
217+
218+ can_handle = method == self .method and self ._can_handle_fun (
219+ _request .path , parse_qs (_request .query , keep_blank_values = True )
220+ )
221+
202222 if can_handle :
203223 Mocket ._last_entry = self
204224 return can_handle
@@ -249,8 +269,27 @@ def single_register(
249269 headers = None ,
250270 exception = None ,
251271 match_querystring = True ,
272+ can_handle_fun = None ,
252273 ** config ,
253274 ):
275+ """
276+ A helper method to register a single Response for a given URI and method.
277+ Instead of passing a list of Response objects, you can just pass the response
278+ parameters directly.
279+
280+ Args:
281+ method (str): The HTTP method (e.g., 'GET', 'POST').
282+ uri (str): The URI to register the response for.
283+ body (str, optional): The body of the response. Defaults to an empty string.
284+ status (int, optional): The HTTP status code. Defaults to 200.
285+ headers (dict, optional): A dictionary of headers to include in the response. Defaults to None.
286+ exception (Exception, optional): An exception to raise instead of returning a response. Defaults to None.
287+ match_querystring (bool, optional): Whether to match the querystring in the URI. Defaults to True.
288+ can_handle_fun (Callable, optional): A custom function to determine if the Entry can handle a request.
289+ Defaults to None. If None, the default matching logic is used. The function should accept two parameters:
290+ path (str), and querystring params (dict), and return a boolean. Method is matched before the function call.
291+ **config: Additional configuration options.
292+ """
254293 response = (
255294 exception
256295 if exception
@@ -262,5 +301,6 @@ def single_register(
262301 uri ,
263302 response ,
264303 match_querystring = match_querystring ,
304+ can_handle_fun = can_handle_fun ,
265305 ** config ,
266306 )
0 commit comments