|
1 | | -from typing import cast |
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import collections |
| 4 | +import itertools |
| 5 | +import os |
| 6 | +from typing import TYPE_CHECKING, ClassVar |
2 | 7 |
|
3 | 8 | import mocket.inject |
4 | | -import mocket.state |
5 | 9 |
|
6 | 10 | # NOTE this is here for backwards-compat to keep old import-paths working |
7 | | -from mocket.socket import MocketSocket as MocketSocket |
| 11 | +# from mocket.socket import MocketSocket as MocketSocket |
| 12 | + |
| 13 | +if TYPE_CHECKING: |
| 14 | + from mocket.entry import MocketEntry |
| 15 | + from mocket.types import Address |
| 16 | + |
| 17 | + |
| 18 | +class Mocket: |
| 19 | + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} |
| 20 | + _address: ClassVar[Address] = (None, None) |
| 21 | + _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) |
| 22 | + _requests: ClassVar[list] = [] |
| 23 | + _namespace: ClassVar[str] = str(id(_entries)) |
| 24 | + _truesocket_recording_dir: ClassVar[str | None] = None |
| 25 | + |
| 26 | + enable = mocket.inject.enable |
| 27 | + disable = mocket.inject.disable |
| 28 | + |
| 29 | + @classmethod |
| 30 | + def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: |
| 31 | + """ |
| 32 | + Given the id() of the caller, return a pair of file descriptors |
| 33 | + as a tuple of two integers: (<read_fd>, <write_fd>) |
| 34 | + """ |
| 35 | + return cls._socket_pairs.get(address, (None, None)) |
| 36 | + |
| 37 | + @classmethod |
| 38 | + def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: |
| 39 | + """ |
| 40 | + Store a pair of file descriptors under the key `id_` |
| 41 | + as a tuple of two integers: (<read_fd>, <write_fd>) |
| 42 | + """ |
| 43 | + cls._socket_pairs[address] = pair |
| 44 | + |
| 45 | + @classmethod |
| 46 | + def register(cls, *entries: MocketEntry) -> None: |
| 47 | + for entry in entries: |
| 48 | + cls._entries[entry.location].append(entry) |
| 49 | + |
| 50 | + @classmethod |
| 51 | + def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: |
| 52 | + host = host or cls._address[0] |
| 53 | + port = port or cls._address[1] |
| 54 | + entries = cls._entries.get((host, port), []) |
| 55 | + for entry in entries: |
| 56 | + if entry.can_handle(data): |
| 57 | + return entry |
| 58 | + return None |
| 59 | + |
| 60 | + @classmethod |
| 61 | + def collect(cls, data) -> None: |
| 62 | + cls._requests.append(data) |
| 63 | + |
| 64 | + @classmethod |
| 65 | + def reset(cls) -> None: |
| 66 | + for r_fd, w_fd in cls._socket_pairs.values(): |
| 67 | + os.close(r_fd) |
| 68 | + os.close(w_fd) |
| 69 | + cls._socket_pairs = {} |
| 70 | + cls._entries = collections.defaultdict(list) |
| 71 | + cls._requests = [] |
| 72 | + |
| 73 | + @classmethod |
| 74 | + def last_request(cls): |
| 75 | + if cls.has_requests(): |
| 76 | + return cls._requests[-1] |
| 77 | + |
| 78 | + @classmethod |
| 79 | + def request_list(cls): |
| 80 | + return cls._requests |
| 81 | + |
| 82 | + @classmethod |
| 83 | + def remove_last_request(cls) -> None: |
| 84 | + if cls.has_requests(): |
| 85 | + del cls._requests[-1] |
8 | 86 |
|
| 87 | + @classmethod |
| 88 | + def has_requests(cls) -> bool: |
| 89 | + return bool(cls.request_list()) |
9 | 90 |
|
10 | | -class _Mocket(mocket.state.MocketState): |
11 | | - def __init__(self) -> None: |
12 | | - self.enable = mocket.inject.enable |
13 | | - self.disable = mocket.inject.disable |
| 91 | + @classmethod |
| 92 | + def get_namespace(cls) -> str: |
| 93 | + return cls._namespace |
14 | 94 |
|
| 95 | + @classmethod |
| 96 | + def get_truesocket_recording_dir(cls) -> str | None: |
| 97 | + return cls._truesocket_recording_dir |
15 | 98 |
|
16 | | -Mocket = cast(_Mocket, mocket.state.state) |
17 | | -Mocket.enable = mocket.inject.enable |
18 | | -Mocket.disable = mocket.inject.disable |
| 99 | + @classmethod |
| 100 | + def assert_fail_if_entries_not_served(cls) -> None: |
| 101 | + """Mocket checks that all entries have been served at least once.""" |
| 102 | + if not all(entry._served for entry in itertools.chain(*cls._entries.values())): |
| 103 | + raise AssertionError("Some Mocket entries have not been served") |
0 commit comments