diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c4481efc..cdb55fe0 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', ' 3.13', 'pypy3.10'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10'] steps: - uses: actions/checkout@v4 diff --git a/mocket/inject.py b/mocket/inject.py index 469ab30b..866ee563 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import os import socket import ssl from types import ModuleType @@ -23,10 +22,7 @@ def _restore(module: ModuleType, name: str) -> None: module.__dict__[name] = original_value -def enable( - namespace: str | None = None, - truesocket_recording_dir: str | None = None, -) -> None: +def enable() -> None: from mocket.socket import ( MocketSocket, mock_create_connection, @@ -73,14 +69,6 @@ def enable( extract_from_urllib3() - from mocket.mocket import Mocket - - Mocket._namespace = namespace - Mocket._truesocket_recording_dir = truesocket_recording_dir - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): - # JSON dumps will be saved here - raise AssertionError - def disable() -> None: for module, name in list(_patches_restore.keys()): @@ -90,7 +78,3 @@ def disable() -> None: from urllib3.contrib.pyopenssl import inject_into_urllib3 inject_into_urllib3() - - from mocket.mocket import Mocket - - Mocket.reset() diff --git a/mocket/mocket.py b/mocket/mocket.py index 3476902d..a01a7b46 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -3,9 +3,11 @@ import collections import itertools import os +from pathlib import Path from typing import TYPE_CHECKING, ClassVar import mocket.inject +from mocket.recording import MocketRecordStorage # NOTE this is here for backwards-compat to keep old import-paths working # from mocket.socket import MocketSocket as MocketSocket @@ -20,11 +22,36 @@ class Mocket: _address: ClassVar[Address] = (None, None) _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) _requests: ClassVar[list] = [] - _namespace: ClassVar[str] = str(id(_entries)) - _truesocket_recording_dir: ClassVar[str | None] = None + _record_storage: ClassVar[MocketRecordStorage | None] = None - enable = mocket.inject.enable - disable = mocket.inject.disable + @classmethod + def enable( + cls, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + ) -> None: + if namespace is None: + namespace = str(id(cls._entries)) + + if truesocket_recording_dir is not None: + recording_dir = Path(truesocket_recording_dir) + + if not recording_dir.is_dir(): + # JSON dumps will be saved here + raise AssertionError + + cls._record_storage = MocketRecordStorage( + directory=recording_dir, + namespace=namespace, + ) + + mocket.inject.enable() + + @classmethod + def disable(cls) -> None: + cls.reset() + + mocket.inject.disable() @classmethod def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: @@ -69,6 +96,7 @@ def reset(cls) -> None: cls._socket_pairs = {} cls._entries = collections.defaultdict(list) cls._requests = [] + cls._record_storage = None @classmethod def last_request(cls): @@ -89,12 +117,16 @@ def has_requests(cls) -> bool: return bool(cls.request_list()) @classmethod - def get_namespace(cls) -> str: - return cls._namespace + def get_namespace(cls) -> str | None: + if not cls._record_storage: + return None + return cls._record_storage.namespace @classmethod def get_truesocket_recording_dir(cls) -> str | None: - return cls._truesocket_recording_dir + if not cls._record_storage: + return None + return str(cls._record_storage.directory) @classmethod def assert_fail_if_entries_not_served(cls) -> None: diff --git a/mocket/recording.py b/mocket/recording.py new file mode 100644 index 00000000..97d2adbe --- /dev/null +++ b/mocket/recording.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import contextlib +import hashlib +import json +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path + +from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.types import Address +from mocket.utils import hexdump, hexload + +hash_function = hashlib.md5 + +with contextlib.suppress(ImportError): + from xxhash_cffi import xxh32 as xxhash_cffi_xxh32 + + hash_function = xxhash_cffi_xxh32 + +with contextlib.suppress(ImportError): + from xxhash import xxh32 as xxhash_xxh32 + + hash_function = xxhash_xxh32 + + +def _hash_prepare_request(data: bytes) -> bytes: + _data = decode_from_bytes(data) + return encode_to_bytes("".join(sorted(_data.split("\r\n")))) + + +def _hash_request(data: bytes) -> str: + _data = _hash_prepare_request(data) + return hash_function(_data).hexdigest() + + +def _hash_request_fallback(data: bytes) -> str: + _data = _hash_prepare_request(data) + return hashlib.md5(_data).hexdigest() + + +@dataclass +class MocketRecord: + host: str + port: int + request: bytes + response: bytes + + +class MocketRecordStorage: + def __init__(self, directory: Path, namespace: str) -> None: + self._directory = directory + self._namespace = namespace + self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = ( + defaultdict(defaultdict) + ) + + self._load() + + @property + def directory(self) -> Path: + return self._directory + + @property + def namespace(self) -> str: + return self._namespace + + @property + def file(self) -> Path: + return self._directory / f"{self._namespace}.json" + + def _load(self) -> None: + if not self.file.exists(): + return + + json_data = self.file.read_text() + records = json.loads(json_data) + for host, port_signature_record in records.items(): + for port, signature_record in port_signature_record.items(): + for signature, record in signature_record.items(): + # NOTE backward-compat + try: + request_data = hexload(record["request"]) + except ValueError: + request_data = record["request"] + + self._records[(host, int(port))][signature] = MocketRecord( + host=host, + port=port, + request=request_data, + response=hexload(record["response"]), + ) + + def _save(self) -> None: + data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict( + lambda: defaultdict(defaultdict) + ) + for address, signature_record in self._records.items(): + host, port = address + for signature, record in signature_record.items(): + data[host][str(port)][signature] = dict( + request=decode_from_bytes(record.request), + response=hexdump(record.response), + ) + + json_data = json.dumps(data, indent=4, sort_keys=True) + self.file.parent.mkdir(exist_ok=True) + self.file.write_text(json_data) + + def get_records(self, address: Address) -> list[MocketRecord]: + return list(self._records[address].values()) + + def get_record(self, address: Address, request: bytes) -> MocketRecord | None: + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + return self._records[address].get(request_signature_fallback) + + request_signature = _hash_request(request) + if request_signature in self._records[address]: + return self._records[address][request_signature] + + return None + + def put_record( + self, + address: Address, + request: bytes, + response: bytes, + ) -> None: + host, port = address + record = MocketRecord( + host=host, + port=port, + request=request, + response=response, + ) + + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + self._records[address][request_signature_fallback] = record + return + + request_signature = _hash_request(request) + self._records[address][request_signature] = record + self._save() diff --git a/mocket/socket.py b/mocket/socket.py index 9480d365..3b1862e2 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -2,18 +2,14 @@ import contextlib import errno -import hashlib -import json import os import select import socket -from json.decoder import JSONDecodeError from types import TracebackType from typing import Any, Type from typing_extensions import Self -from mocket.compat import decode_from_bytes, encode_to_bytes from mocket.entry import MocketEntry from mocket.io import MocketSocketIO from mocket.mocket import Mocket @@ -24,21 +20,11 @@ WriteableBuffer, _RetAddress, ) -from mocket.utils import hexdump, hexload true_gethostbyname = socket.gethostbyname true_socket = socket.socket -xxh32 = None -try: - from xxhash import xxh32 -except ImportError: # pragma: no cover - with contextlib.suppress(ImportError): - from xxhash_cffi import xxh32 -hasher = xxh32 or hashlib.md5 - - def mock_create_connection(address, timeout=None, source_address=None): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: @@ -77,10 +63,6 @@ def mock_socketpair(*args, **kwargs): return _socket.socketpair(*args, **kwargs) -def _hash_request(h, req): - return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() - - class MocketSocket: def __init__( self, @@ -235,87 +217,47 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes: exc.args = (0,) raise exc - def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: - if not MocketMode().is_allowed((self._host, self._port)): + def true_sendall(self, data: bytes, *args: Any, **kwargs: Any) -> bytes: + if not MocketMode().is_allowed(self._address): MocketMode.raise_not_allowed() - req = decode_from_bytes(data) - # make request unique again - req_signature = _hash_request(hasher, req) - # port should be always a string - port = str(self._port) - - # prepare responses dictionary - responses = {} - - if Mocket.get_truesocket_recording_dir(): - path = os.path.join( - Mocket.get_truesocket_recording_dir(), - Mocket.get_namespace() + ".json", + # try to get the response from recordings + if Mocket._record_storage: + record = Mocket._record_storage.get_record( + address=self._address, + request=data, ) - # check if there's already a recorded session dumped to a JSON file - try: - with open(path) as f: - responses = json.load(f) - # if not, create a new dictionary - except (FileNotFoundError, JSONDecodeError): - pass - - try: - try: - response_dict = responses[self._host][port][req_signature] - except KeyError: - if hasher is not hashlib.md5: - # Fallback for backwards compatibility - req_signature = _hash_request(hashlib.md5, req) - response_dict = responses[self._host][port][req_signature] - else: - raise - except KeyError: - # preventing next KeyError exceptions - responses.setdefault(self._host, {}) - responses[self._host].setdefault(port, {}) - responses[self._host][port].setdefault(req_signature, {}) - response_dict = responses[self._host][port][req_signature] - - # try to get the response from the dictionary - try: - encoded_response = hexload(response_dict["response"]) - # if not available, call the real sendall - except KeyError: - host, port = self._host, self._port - host = true_gethostbyname(host) - - with contextlib.suppress(OSError, ValueError): - # already connected - self._true_socket.connect((host, port)) - self._true_socket.sendall(data, *args, **kwargs) - encoded_response = b"" - # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 - while True: - more_to_read = select.select([self._true_socket], [], [], 0.1)[0] - if not more_to_read and encoded_response: - break - new_content = self._true_socket.recv(self._buflen) - if not new_content: - break - encoded_response += new_content - - # dump the resulting dictionary to a JSON file - if Mocket.get_truesocket_recording_dir(): - # update the dictionary with request and response lines - response_dict["request"] = req - response_dict["response"] = hexdump(encoded_response) - - with open(path, mode="w") as f: - f.write( - decode_from_bytes( - json.dumps(responses, indent=4, sort_keys=True) - ) - ) - - # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO - return encoded_response + if record is not None: + return record.response + + host, port = self._address + host = true_gethostbyname(host) + + with contextlib.suppress(OSError, ValueError): + # already connected + self._true_socket.connect((host, port)) + + self._true_socket.sendall(data, *args, **kwargs) + response = b"" + # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 + while True: + more_to_read = select.select([self._true_socket], [], [], 0.1)[0] + if not more_to_read and response: + break + new_content = self._true_socket.recv(self._buflen) + if not new_content: + break + response += new_content + + # store request+response in recordings + if Mocket._record_storage: + Mocket._record_storage.put_record( + address=self._address, + request=data, + response=response, + ) + + return response def send( self, diff --git a/mocket/utils.py b/mocket/utils.py index 59403954..31557a58 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -6,12 +6,6 @@ from mocket.compat import decode_from_bytes, encode_to_bytes -# NOTE this is here for backwards-compat to keep old import-paths working -from mocket.io import MocketSocketIO as MocketSocketCore - -# NOTE this is here for backwards-compat to keep old import-paths working -from mocket.mode import MocketMode - SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 @@ -30,7 +24,10 @@ def hexload(string: str) -> bytes: True """ string_no_spaces = "".join(string.split()) - return encode_to_bytes(binascii.unhexlify(string_no_spaces)) + try: + return encode_to_bytes(binascii.unhexlify(string_no_spaces)) + except binascii.Error as e: + raise ValueError from e def get_mocketize(wrapper_: Callable) -> Callable: @@ -45,8 +42,6 @@ def get_mocketize(wrapper_: Callable) -> Callable: __all__ = ( - "MocketMode", - "MocketSocketCore", "SSL_PROTOCOL", "get_mocketize", "hexdump",