|
1 | | -import collections |
2 | | -import itertools |
3 | | -import os |
4 | | -from typing import Optional, Tuple |
| 1 | +from typing import cast |
5 | 2 |
|
6 | 3 | import mocket.inject |
| 4 | +import mocket.state |
7 | 5 |
|
8 | 6 | # NOTE this is here for backwards-compat to keep old import-paths working |
9 | 7 | from mocket.socket import MocketSocket as MocketSocket |
10 | 8 |
|
11 | 9 |
|
12 | | -class Mocket: |
13 | | - _socket_pairs = {} |
14 | | - _address = (None, None) |
15 | | - _entries = collections.defaultdict(list) |
16 | | - _requests = [] |
17 | | - _namespace = str(id(_entries)) |
18 | | - _truesocket_recording_dir = None |
| 10 | +class _Mocket(mocket.state.MocketState): |
| 11 | + def __init__(self) -> None: |
| 12 | + self.enable = mocket.inject.enable |
| 13 | + self.disable = mocket.inject.disable |
19 | 14 |
|
20 | | - @classmethod |
21 | | - def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: |
22 | | - """ |
23 | | - Given the id() of the caller, return a pair of file descriptors |
24 | | - as a tuple of two integers: (<read_fd>, <write_fd>) |
25 | | - """ |
26 | | - return cls._socket_pairs.get(address, (None, None)) |
27 | 15 |
|
28 | | - @classmethod |
29 | | - def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: |
30 | | - """ |
31 | | - Store a pair of file descriptors under the key `id_` |
32 | | - as a tuple of two integers: (<read_fd>, <write_fd>) |
33 | | - """ |
34 | | - cls._socket_pairs[address] = pair |
35 | | - |
36 | | - @classmethod |
37 | | - def register(cls, *entries): |
38 | | - for entry in entries: |
39 | | - cls._entries[entry.location].append(entry) |
40 | | - |
41 | | - @classmethod |
42 | | - def get_entry(cls, host, port, data): |
43 | | - host = host or Mocket._address[0] |
44 | | - port = port or Mocket._address[1] |
45 | | - entries = cls._entries.get((host, port), []) |
46 | | - for entry in entries: |
47 | | - if entry.can_handle(data): |
48 | | - return entry |
49 | | - |
50 | | - @classmethod |
51 | | - def collect(cls, data): |
52 | | - cls.request_list().append(data) |
53 | | - |
54 | | - @classmethod |
55 | | - def reset(cls): |
56 | | - for r_fd, w_fd in cls._socket_pairs.values(): |
57 | | - os.close(r_fd) |
58 | | - os.close(w_fd) |
59 | | - cls._socket_pairs = {} |
60 | | - cls._entries = collections.defaultdict(list) |
61 | | - cls._requests = [] |
62 | | - |
63 | | - @classmethod |
64 | | - def last_request(cls): |
65 | | - if cls.has_requests(): |
66 | | - return cls.request_list()[-1] |
67 | | - |
68 | | - @classmethod |
69 | | - def request_list(cls): |
70 | | - return cls._requests |
71 | | - |
72 | | - @classmethod |
73 | | - def remove_last_request(cls): |
74 | | - if cls.has_requests(): |
75 | | - del cls._requests[-1] |
76 | | - |
77 | | - @classmethod |
78 | | - def has_requests(cls): |
79 | | - return bool(cls.request_list()) |
80 | | - |
81 | | - @classmethod |
82 | | - def get_namespace(cls): |
83 | | - return cls._namespace |
84 | | - |
85 | | - @staticmethod |
86 | | - def enable(namespace=None, truesocket_recording_dir=None): |
87 | | - mocket.inject.enable(namespace, truesocket_recording_dir) |
88 | | - |
89 | | - @staticmethod |
90 | | - def disable(): |
91 | | - mocket.inject.disable() |
92 | | - |
93 | | - @classmethod |
94 | | - def get_truesocket_recording_dir(cls): |
95 | | - return cls._truesocket_recording_dir |
96 | | - |
97 | | - @classmethod |
98 | | - def assert_fail_if_entries_not_served(cls): |
99 | | - """Mocket checks that all entries have been served at least once.""" |
100 | | - if not all(entry._served for entry in itertools.chain(*cls._entries.values())): |
101 | | - raise AssertionError("Some Mocket entries have not been served") |
| 16 | +Mocket = cast(_Mocket, mocket.state.state) |
| 17 | +Mocket.enable = mocket.inject.enable |
| 18 | +Mocket.disable = mocket.inject.disable |
0 commit comments