Skip to content

Commit 5720ea7

Browse files
committed
Define an alternative can_handle logic by passing a callable.
1 parent d193c96 commit 5720ea7

File tree

2 files changed

+68
-12
lines changed

2 files changed

+68
-12
lines changed

mocket/mocks/mockhttp.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import time
33
from functools import cached_property
44
from http.server import BaseHTTPRequestHandler
5+
from typing import Callable, Optional
56
from urllib.parse import parse_qs, unquote, urlsplit
67

78
from h11 import SERVER, Connection, Data
@@ -82,9 +83,7 @@ def __init__(self, body="", status=200, headers=None):
8283
self.status = status
8384

8485
self.set_base_headers()
85-
86-
if headers is not None:
87-
self.set_extra_headers(headers)
86+
self.set_extra_headers(headers)
8887

8988
self.data = self.get_protocol_data() + self.body
9089

@@ -142,9 +141,19 @@ class Entry(MocketEntry):
142141
request_cls = Request
143142
response_cls = Response
144143

145-
default_config = {"match_querystring": True}
144+
default_config = {"match_querystring": True, "can_handle_fun": None}
145+
_can_handle_fun: Optional[Callable] = None
146+
147+
def __init__(
148+
self,
149+
uri,
150+
method,
151+
responses,
152+
match_querystring: bool = True,
153+
can_handle_fun: Optional[Callable] = None,
154+
):
155+
self._can_handle_fun = can_handle_fun if can_handle_fun else self._can_handle
146156

147-
def __init__(self, uri, method, responses, match_querystring: bool = True):
148157
uri = urlsplit(uri)
149158

150159
port = uri.port
@@ -177,6 +186,15 @@ def collect(self, data):
177186

178187
return consume_response
179188

189+
def _can_handle(self, method, path, query):
190+
can_handle = path == self.path and method == self.method
191+
if self._match_querystring:
192+
kw = dict(keep_blank_values=True)
193+
can_handle = can_handle and parse_qs(query, **kw) == parse_qs(
194+
self.query, **kw
195+
)
196+
return can_handle
197+
180198
def can_handle(self, data):
181199
r"""
182200
>>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b'<html/>'),))
@@ -192,13 +210,10 @@ def can_handle(self, data):
192210
except ValueError:
193211
return self is getattr(Mocket, "_last_entry", None)
194212

195-
uri = urlsplit(path)
196-
can_handle = uri.path == self.path and method == self.method
197-
if self._match_querystring:
198-
kw = dict(keep_blank_values=True)
199-
can_handle = can_handle and parse_qs(uri.query, **kw) == parse_qs(
200-
self.query, **kw
201-
)
213+
_request = urlsplit(path)
214+
215+
can_handle = self._can_handle_fun(method, _request.path, _request.query)
216+
202217
if can_handle:
203218
Mocket._last_entry = self
204219
return can_handle
@@ -249,8 +264,27 @@ def single_register(
249264
headers=None,
250265
exception=None,
251266
match_querystring=True,
267+
can_handle_fun=None,
252268
**config,
253269
):
270+
"""
271+
A helper method to register a single Response for a given URI and method.
272+
Instead of passing a list of Response objects, you can just pass the response
273+
parameters directly.
274+
275+
Args:
276+
method (str): The HTTP method (e.g., 'GET', 'POST').
277+
uri (str): The URI to register the response for.
278+
body (str, optional): The body of the response. Defaults to an empty string.
279+
status (int, optional): The HTTP status code. Defaults to 200.
280+
headers (dict, optional): A dictionary of headers to include in the response. Defaults to None.
281+
exception (Exception, optional): An exception to raise instead of returning a response. Defaults to None.
282+
match_querystring (bool, optional): Whether to match the querystring in the URI. Defaults to True.
283+
can_handle_fun (Callable, optional): A custom function to determine if the entry can handle a request.
284+
Defaults to None. If None, the default matching logic is used. The function should accept three parameters:
285+
method (str), path (str), and querystring params (dict), and return a boolean.
286+
**config: Additional configuration options.
287+
"""
254288
response = (
255289
exception
256290
if exception
@@ -262,5 +296,6 @@ def single_register(
262296
uri,
263297
response,
264298
match_querystring=match_querystring,
299+
can_handle_fun=can_handle_fun,
265300
**config,
266301
)

tests/test_http.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,24 @@ def test_mocket_with_no_path(self):
455455
response = urlopen("http://httpbin.local/")
456456
self.assertEqual(response.code, 202)
457457
self.assertEqual(Mocket._entries[("httpbin.local", 80)][0].path, "/")
458+
459+
@mocketize
460+
def test_can_handle(self):
461+
Entry.single_register(
462+
Entry.GET,
463+
"http://testme.org/",
464+
body=json.dumps({"message": "Gotcha!"}),
465+
can_handle_fun=lambda m, p, q: "a" in q,
466+
)
467+
Entry.single_register(
468+
Entry.GET,
469+
"http://testme.org/foobar",
470+
body=json.dumps({"message": "Missed!"}),
471+
match_querystring=False,
472+
)
473+
response = requests.get("http://testme.org/foobar?a=1")
474+
self.assertEqual(response.status_code, 200)
475+
self.assertEqual(response.json(), {"message": "Gotcha!"})
476+
response = requests.get("http://testme.org/foobar?b=2")
477+
self.assertEqual(response.status_code, 200)
478+
self.assertEqual(response.json(), {"message": "Missed!"})

0 commit comments

Comments
 (0)