22import time
33from functools import cached_property
44from http .server import BaseHTTPRequestHandler
5+ from typing import Callable , Optional
56from urllib .parse import parse_qs , unquote , urlsplit
67
78from h11 import SERVER , Connection , Data
@@ -82,9 +83,7 @@ def __init__(self, body="", status=200, headers=None):
8283 self .status = status
8384
8485 self .set_base_headers ()
85-
86- if headers is not None :
87- self .set_extra_headers (headers )
86+ self .set_extra_headers (headers )
8887
8988 self .data = self .get_protocol_data () + self .body
9089
@@ -142,9 +141,19 @@ class Entry(MocketEntry):
142141 request_cls = Request
143142 response_cls = Response
144143
145- default_config = {"match_querystring" : True }
144+ default_config = {"match_querystring" : True , "can_handle_fun" : None }
145+ _can_handle_fun : Optional [Callable ] = None
146+
147+ def __init__ (
148+ self ,
149+ uri ,
150+ method ,
151+ responses ,
152+ match_querystring : bool = True ,
153+ can_handle_fun : Optional [Callable ] = None ,
154+ ):
155+ self ._can_handle_fun = can_handle_fun if can_handle_fun else self ._can_handle
146156
147- def __init__ (self , uri , method , responses , match_querystring : bool = True ):
148157 uri = urlsplit (uri )
149158
150159 port = uri .port
@@ -177,6 +186,15 @@ def collect(self, data):
177186
178187 return consume_response
179188
189+ def _can_handle (self , method , path , query ):
190+ can_handle = path == self .path and method == self .method
191+ if self ._match_querystring :
192+ kw = dict (keep_blank_values = True )
193+ can_handle = can_handle and parse_qs (query , ** kw ) == parse_qs (
194+ self .query , ** kw
195+ )
196+ return can_handle
197+
180198 def can_handle (self , data ):
181199 r"""
182200 >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b'<html/>'),))
@@ -192,13 +210,10 @@ def can_handle(self, data):
192210 except ValueError :
193211 return self is getattr (Mocket , "_last_entry" , None )
194212
195- uri = urlsplit (path )
196- can_handle = uri .path == self .path and method == self .method
197- if self ._match_querystring :
198- kw = dict (keep_blank_values = True )
199- can_handle = can_handle and parse_qs (uri .query , ** kw ) == parse_qs (
200- self .query , ** kw
201- )
213+ _request = urlsplit (path )
214+
215+ can_handle = self ._can_handle_fun (method , _request .path , _request .query )
216+
202217 if can_handle :
203218 Mocket ._last_entry = self
204219 return can_handle
@@ -249,8 +264,27 @@ def single_register(
249264 headers = None ,
250265 exception = None ,
251266 match_querystring = True ,
267+ can_handle_fun = None ,
252268 ** config ,
253269 ):
270+ """
271+ A helper method to register a single Response for a given URI and method.
272+ Instead of passing a list of Response objects, you can just pass the response
273+ parameters directly.
274+
275+ Args:
276+ method (str): The HTTP method (e.g., 'GET', 'POST').
277+ uri (str): The URI to register the response for.
278+ body (str, optional): The body of the response. Defaults to an empty string.
279+ status (int, optional): The HTTP status code. Defaults to 200.
280+ headers (dict, optional): A dictionary of headers to include in the response. Defaults to None.
281+ exception (Exception, optional): An exception to raise instead of returning a response. Defaults to None.
282+ match_querystring (bool, optional): Whether to match the querystring in the URI. Defaults to True.
283+ can_handle_fun (Callable, optional): A custom function to determine if the entry can handle a request.
284+ Defaults to None. If None, the default matching logic is used. The function should accept three parameters:
285+ method (str), path (str), and querystring params (dict), and return a boolean.
286+ **config: Additional configuration options.
287+ """
254288 response = (
255289 exception
256290 if exception
@@ -262,5 +296,6 @@ def single_register(
262296 uri ,
263297 response ,
264298 match_querystring = match_querystring ,
299+ can_handle_fun = can_handle_fun ,
265300 ** config ,
266301 )
0 commit comments