Skip to content

Commit b4e3484

Browse files
committed
refactor: migrate mocket.compat.entry to use mocket.core.entry
1 parent dc0e2d9 commit b4e3484

File tree

4 files changed

+51
-82
lines changed

4 files changed

+51
-82
lines changed

mocket/compat/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from mocket.compat.entry import MocketEntry
1+
from mocket.compat.entry import MocketEntry, Response
22
from mocket.core.ssl.context import MocketSSLContext as FakeSSLContext
33

44
__all__ = [
55
"FakeSSLContext",
66
"MocketEntry",
7+
"Response",
78
]

mocket/compat/entry.py

Lines changed: 33 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,36 @@
1-
import collections.abc
2-
3-
from mocket.core.compat import encode_to_bytes
4-
from mocket.core.mocket import Mocket
5-
6-
7-
class MocketEntry:
8-
class Response(bytes):
9-
@property
10-
def data(self):
11-
return self
12-
13-
response_index = 0
14-
request_cls = bytes
15-
response_cls = Response
16-
responses = None
17-
_served = None
18-
19-
def __init__(self, location, responses):
20-
self._served = False
21-
self.location = location
22-
23-
if not isinstance(responses, collections.abc.Iterable):
1+
from __future__ import annotations
2+
3+
from mocket.bytes import MocketBytesEntry, MocketBytesResponse
4+
from mocket.core.types import Address
5+
6+
7+
class Response(MocketBytesResponse):
8+
def __init__(self, data: bytes | str | bool) -> None:
9+
if isinstance(data, str):
10+
data = data.encode()
11+
elif isinstance(data, bool):
12+
data = bytes(data)
13+
self._data = data
14+
15+
16+
class MocketEntry(MocketBytesEntry):
17+
def __init__(
18+
self,
19+
location: Address,
20+
responses: list[MocketBytesResponse | Exception | bytes | str | bool]
21+
| MocketBytesResponse
22+
| Exception
23+
| bytes
24+
| str
25+
| bool,
26+
) -> None:
27+
if not isinstance(responses, list):
2428
responses = [responses]
2529

26-
if not responses:
27-
self.responses = [self.response_cls(encode_to_bytes(""))]
28-
else:
29-
self.responses = []
30-
for r in responses:
31-
if not isinstance(r, BaseException) and not getattr(r, "data", False):
32-
if isinstance(r, str):
33-
r = encode_to_bytes(r)
34-
r = self.response_cls(r)
35-
self.responses.append(r)
36-
37-
def __repr__(self):
38-
return f"{self.__class__.__name__}(location={self.location})"
39-
40-
@staticmethod
41-
def can_handle(data):
42-
return True
43-
44-
def collect(self, data):
45-
req = self.request_cls(data)
46-
Mocket.collect(req)
47-
48-
def get_response(self):
49-
response = self.responses[self.response_index]
50-
if self.response_index < len(self.responses) - 1:
51-
self.response_index += 1
52-
53-
self._served = True
54-
55-
if isinstance(response, BaseException):
56-
raise response
30+
_responses = []
31+
for response in responses:
32+
if not isinstance(response, (MocketBytesResponse, Exception)):
33+
response = MocketBytesResponse(response)
34+
_responses.append(response)
5735

58-
return response.data
36+
super().__init__(address=location, responses=_responses)

mocket/core/mocket.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,19 @@
99
import mocket.core.inject
1010
from mocket.core.recording import MocketRecordStorage
1111

12-
# NOTE this is here for backwards-compat to keep old import-paths working
13-
# from mocket.socket import MocketSocket as MocketSocket
14-
1512
if TYPE_CHECKING:
16-
from mocket.compat.entry import MocketEntry
17-
from mocket.core.entry import MocketBaseEntry
13+
from mocket.core.entry import MocketBaseEntry, MocketBaseRequest
1814
from mocket.core.types import Address
1915

2016

2117
class Mocket:
2218
_socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {}
2319
_address: ClassVar[Address] = (None, None)
24-
_entries: ClassVar[dict[Address, list[MocketEntry | MocketBaseEntry]]] = (
25-
collections.defaultdict(list)
20+
_entries: ClassVar[dict[Address, list[MocketBaseEntry]]] = collections.defaultdict(
21+
list
2622
)
27-
_requests: ClassVar[list] = []
23+
_requests: ClassVar[list[MocketBaseRequest]] = []
24+
_last_entry: ClassVar[MocketBaseEntry | None] = None
2825
_record_storage: ClassVar[MocketRecordStorage | None] = None
2926

3027
@classmethod
@@ -73,18 +70,12 @@ def set_pair(cls, address: Address, pair: tuple[int, int]) -> None:
7370
cls._socket_pairs[address] = pair
7471

7572
@classmethod
76-
def register(cls, *entries: MocketEntry | MocketBaseEntry) -> None:
73+
def register(cls, *entries: MocketBaseEntry) -> None:
7774
for entry in entries:
78-
address = entry.location if hasattr(entry, "location") else entry.address
79-
cls._entries[address].append(entry)
75+
cls._entries[entry.address].append(entry)
8076

8177
@classmethod
82-
def get_entry(
83-
cls,
84-
host: str,
85-
port: int,
86-
data: bytes,
87-
) -> MocketEntry | MocketBaseEntry | None:
78+
def get_entry(cls, host: str, port: int, data) -> MocketBaseEntry | None:
8879
host = host or cls._address[0]
8980
port = port or cls._address[1]
9081
entries = cls._entries.get((host, port), [])
@@ -108,12 +99,13 @@ def reset(cls) -> None:
10899
cls._record_storage = None
109100

110101
@classmethod
111-
def last_request(cls):
102+
def last_request(cls) -> MocketBaseRequest | None:
112103
if cls.has_requests():
113104
return cls._requests[-1]
105+
return None
114106

115107
@classmethod
116-
def request_list(cls):
108+
def request_list(cls) -> list[MocketBaseRequest]:
117109
return cls._requests
118110

119111
@classmethod
@@ -140,9 +132,7 @@ def get_truesocket_recording_dir(cls) -> str | None:
140132
@classmethod
141133
def assert_fail_if_entries_not_served(cls) -> None:
142134
"""Mocket checks that all entries have been served at least once."""
143-
144-
def served(entry: MocketEntry | MocketBaseEntry) -> bool | None:
145-
return entry._served if hasattr(entry, "_served") else entry.served_response
146-
147-
if not all(served(entry) for entry in itertools.chain(*cls._entries.values())):
135+
if not all(
136+
entry.served_response for entry in itertools.chain(*cls._entries.values())
137+
):
148138
raise AssertionError("Some Mocket entries have not been served")

mocket/core/socket.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from typing_extensions import Self
1212

13-
from mocket.compat.entry import MocketEntry
13+
from mocket.core.entry import MocketBaseEntry
1414
from mocket.core.io import MocketSocketIO
1515
from mocket.core.mocket import Mocket
1616
from mocket.core.mode import MocketMode
@@ -167,7 +167,7 @@ def connect(self, address: Address) -> None:
167167
def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO:
168168
return self.io
169169

170-
def get_entry(self, data: bytes) -> MocketEntry | None:
170+
def get_entry(self, data: bytes) -> MocketBaseEntry | None:
171171
return Mocket.get_entry(self._host, self._port, data)
172172

173173
def sendall(self, data, entry=None, *args, **kwargs):

0 commit comments

Comments
 (0)