Skip to content

Commit 5833f65

Browse files
committed
revert to singleton
1 parent cd59dbc commit 5833f65

12 files changed

Lines changed: 133 additions & 131 deletions

File tree

mocket/entry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import collections.abc
22

3-
import mocket.state
43
from mocket.compat import encode_to_bytes
4+
from mocket.mocket import Mocket
55

66

77
class MocketEntry:
@@ -43,7 +43,7 @@ def can_handle(data):
4343

4444
def collect(self, data):
4545
req = self.request_cls(data)
46-
mocket.state.state.collect(req)
46+
Mocket.collect(req)
4747

4848
def get_response(self):
4949
response = self.responses[self.response_index]

mocket/inject.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from urllib3.connection import match_hostname as urllib3_match_hostname
99
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
1010

11-
import mocket.state
12-
1311
try:
1412
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
1513
except ImportError:
@@ -44,11 +42,12 @@ def enable(
4442
namespace: str | None = None,
4543
truesocket_recording_dir: str | None = None,
4644
) -> None:
45+
from mocket.mocket import Mocket
4746
from mocket.socket import MocketSocket, create_connection, socketpair
4847
from mocket.ssl import FakeSSLContext
4948

50-
mocket.state.state._namespace = namespace
51-
mocket.state.state._truesocket_recording_dir = truesocket_recording_dir
49+
Mocket._namespace = namespace
50+
Mocket._truesocket_recording_dir = truesocket_recording_dir
5251

5352
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
5453
# JSON dumps will be saved here
@@ -92,6 +91,8 @@ def enable(
9291

9392

9493
def disable() -> None:
94+
from mocket.mocket import Mocket
95+
9596
socket.socket = socket.__dict__["socket"] = true_socket
9697
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
9798
socket.SocketType = socket.__dict__["SocketType"] = true_socket
@@ -121,7 +122,7 @@ def disable() -> None:
121122
urllib3.connection.match_hostname = urllib3.connection.__dict__[
122123
"match_hostname"
123124
] = true_urllib3_match_hostname
124-
mocket.state.state.reset()
125+
Mocket.reset()
125126
if pyopenssl_override: # pragma: no cover
126127
# Put the pyopenssl version back in place
127128
inject_into_urllib3()

mocket/io.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import io
22
import os
33

4-
import mocket.state
4+
from mocket.mocket import Mocket
55

66

77
class MocketSocketCore(io.BytesIO):
@@ -12,6 +12,6 @@ def __init__(self, address) -> None:
1212
def write(self, content):
1313
super().write(content)
1414

15-
_, w_fd = mocket.state.state.get_pair(self._address)
15+
_, w_fd = Mocket.get_pair(self._address)
1616
if w_fd:
1717
os.write(w_fd, content)

mocket/mocket.py

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,103 @@
1-
from typing import cast
1+
from __future__ import annotations
2+
3+
import collections
4+
import itertools
5+
import os
6+
from typing import TYPE_CHECKING, ClassVar
27

38
import mocket.inject
4-
import mocket.state
59

610
# NOTE this is here for backwards-compat to keep old import-paths working
7-
from mocket.socket import MocketSocket as MocketSocket
11+
# from mocket.socket import MocketSocket as MocketSocket
12+
13+
if TYPE_CHECKING:
14+
from mocket.entry import MocketEntry
15+
from mocket.types import Address
16+
17+
18+
class Mocket:
19+
_socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {}
20+
_address: ClassVar[Address] = (None, None)
21+
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
22+
_requests: ClassVar[list] = []
23+
_namespace: ClassVar[str] = str(id(_entries))
24+
_truesocket_recording_dir: ClassVar[str | None] = None
25+
26+
enable = mocket.inject.enable
27+
disable = mocket.inject.disable
28+
29+
@classmethod
30+
def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]:
31+
"""
32+
Given the id() of the caller, return a pair of file descriptors
33+
as a tuple of two integers: (<read_fd>, <write_fd>)
34+
"""
35+
return cls._socket_pairs.get(address, (None, None))
36+
37+
@classmethod
38+
def set_pair(cls, address: Address, pair: tuple[int, int]) -> None:
39+
"""
40+
Store a pair of file descriptors under the key `id_`
41+
as a tuple of two integers: (<read_fd>, <write_fd>)
42+
"""
43+
cls._socket_pairs[address] = pair
44+
45+
@classmethod
46+
def register(cls, *entries: MocketEntry) -> None:
47+
for entry in entries:
48+
cls._entries[entry.location].append(entry)
49+
50+
@classmethod
51+
def get_entry(cls, host: str, port: int, data) -> MocketEntry | None:
52+
host = host or cls._address[0]
53+
port = port or cls._address[1]
54+
entries = cls._entries.get((host, port), [])
55+
for entry in entries:
56+
if entry.can_handle(data):
57+
return entry
58+
return None
59+
60+
@classmethod
61+
def collect(cls, data) -> None:
62+
cls._requests.append(data)
63+
64+
@classmethod
65+
def reset(cls) -> None:
66+
for r_fd, w_fd in cls._socket_pairs.values():
67+
os.close(r_fd)
68+
os.close(w_fd)
69+
cls._socket_pairs = {}
70+
cls._entries = collections.defaultdict(list)
71+
cls._requests = []
72+
73+
@classmethod
74+
def last_request(cls):
75+
if cls.has_requests():
76+
return cls._requests[-1]
77+
78+
@classmethod
79+
def request_list(cls):
80+
return cls._requests
81+
82+
@classmethod
83+
def remove_last_request(cls) -> None:
84+
if cls.has_requests():
85+
del cls._requests[-1]
886

87+
@classmethod
88+
def has_requests(cls) -> bool:
89+
return bool(cls.request_list())
990

10-
class _Mocket(mocket.state.MocketState):
11-
def __init__(self) -> None:
12-
self.enable = mocket.inject.enable
13-
self.disable = mocket.inject.disable
91+
@classmethod
92+
def get_namespace(cls) -> str:
93+
return cls._namespace
1494

95+
@classmethod
96+
def get_truesocket_recording_dir(cls) -> str | None:
97+
return cls._truesocket_recording_dir
1598

16-
Mocket = cast(_Mocket, mocket.state.state)
17-
Mocket.enable = mocket.inject.enable
18-
Mocket.disable = mocket.inject.disable
99+
@classmethod
100+
def assert_fail_if_entries_not_served(cls) -> None:
101+
"""Mocket checks that all entries have been served at least once."""
102+
if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
103+
raise AssertionError("Some Mocket entries have not been served")

mocket/mockhttp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from h11 import SERVER, Connection, Data
88
from h11 import Request as H11Request
99

10-
import mocket.state
1110
from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes
1211
from mocket.entry import MocketEntry
12+
from mocket.mocket import Mocket
1313

1414
STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()}
1515
CRLF = "\r\n"
@@ -165,7 +165,7 @@ def collect(self, data):
165165

166166
decoded_data = decode_from_bytes(data)
167167
if not decoded_data.startswith(Entry.METHODS):
168-
mocket.state.state.remove_last_request()
168+
Mocket.remove_last_request()
169169
self._sent_data += data
170170
consume_response = False
171171
else:
@@ -188,7 +188,7 @@ def can_handle(self, data):
188188
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
189189
method, path, _ = self._parse_requestline(requestline)
190190
except ValueError:
191-
return self is getattr(mocket.state.state, "_last_entry", None)
191+
return self is getattr(Mocket, "_last_entry", None)
192192

193193
uri = urlsplit(path)
194194
can_handle = uri.path == self.path and method == self.method
@@ -198,7 +198,7 @@ def can_handle(self, data):
198198
self.query, **kw
199199
)
200200
if can_handle:
201-
mocket.state.state._last_entry = self
201+
Mocket._last_entry = self
202202
return can_handle
203203

204204
@staticmethod
@@ -234,7 +234,7 @@ def register(cls, method, uri, *responses, **config):
234234
if config["add_trailing_slash"] and not urlsplit(uri).path:
235235
uri += "/"
236236

237-
mocket.state.state.register(
237+
Mocket.register(
238238
cls(uri, method, responses, match_querystring=config["match_querystring"])
239239
)
240240

mocket/mockredis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from itertools import chain
22

3-
import mocket.state
43
from mocket.compat import (
54
decode_from_bytes,
65
encode_to_bytes,
76
shsplit,
87
)
98
from mocket.entry import MocketEntry
9+
from mocket.mocket import Mocket
1010

1111

1212
class Request:
@@ -80,7 +80,7 @@ def register(cls, addr, command, *responses):
8080
r if isinstance(r, BaseException) else cls.response_cls(r)
8181
for r in responses
8282
]
83-
mocket.state.state.register(cls(addr, command, responses))
83+
Mocket.register(cls(addr, command, responses))
8484

8585
@classmethod
8686
def register_response(cls, command, response, addr=None):

mocket/mode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from typing import TYPE_CHECKING, Any, ClassVar
44

5-
import mocket.state
65
from mocket.exceptions import StrictMocketException
6+
from mocket.mocket import Mocket
77

88
if TYPE_CHECKING: # pragma: no cover
99
from typing import NoReturn
@@ -34,7 +34,7 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool:
3434
def raise_not_allowed() -> NoReturn:
3535
current_entries = [
3636
(location, "\n ".join(map(str, entries)))
37-
for location, entries in mocket.state.state._entries.items()
37+
for location, entries in Mocket._entries.items()
3838
]
3939
formatted_entries = "\n".join(
4040
[f" {location}:\n {entries}" for location, entries in current_entries]

mocket/plugins/httpretty/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import mocket.inject
2-
import mocket.state
32
from mocket import mocketize
43
from mocket.async_mocket import async_mocketize
54
from mocket.compat import ENCODING
5+
from mocket.mocket import Mocket
66
from mocket.mockhttp import Entry as MocketHttpEntry
77
from mocket.mockhttp import Request as MocketHttpRequest
88
from mocket.mockhttp import Response as MocketHttpResponse
@@ -49,7 +49,7 @@ class Entry(MocketHttpEntry):
4949

5050
enable = mocket.inject.enable
5151
disable = mocket.inject.disable
52-
reset = mocket.state.state.reset
52+
reset = Mocket.reset
5353

5454
GET = Entry.GET
5555
PUT = Entry.PUT
@@ -104,9 +104,9 @@ class MocketHTTPretty:
104104

105105
def __getattr__(self, name):
106106
if name == "last_request":
107-
return mocket.state.state.last_request()
107+
return Mocket.last_request()
108108
if name == "latest_requests":
109-
return mocket.state.state.request_list()
109+
return Mocket.request_list()
110110
return getattr(Entry, name)
111111

112112

mocket/plugins/pook_mock_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
MockEngine = object
55

66
import mocket.inject
7-
import mocket.state
7+
from mocket.mocket import Mocket
88
from mocket.mockhttp import Entry, Response
99

1010

@@ -36,7 +36,7 @@ def single_register(
3636
[Response(body=body, status=status, headers=headers)],
3737
match_querystring=match_querystring,
3838
)
39-
mocket.state.state.register(entry)
39+
Mocket.register(entry)
4040
return entry
4141

4242

0 commit comments

Comments
 (0)