Skip to content

Commit 199e903

Browse files
committed
refactor: make injection code more readable and make backwards-compat more explicit
1 parent 0da2722 commit 199e903

File tree

6 files changed

+159
-127
lines changed

6 files changed

+159
-127
lines changed

mocket/inject.py

Lines changed: 124 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import os
4-
import socket
5-
import ssl
5+
from types import ModuleType
6+
from typing import Any
67

7-
import urllib3
8+
from packaging.version import Version
89

9-
try: # pragma: no cover
10-
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
10+
from mocket.utils import package_version, python_version
1111

12-
pyopenssl_override = True
13-
except ImportError:
14-
pyopenssl_override = False
1512

13+
def _replace(module: ModuleType, name: str, new_value: Any) -> None:
14+
module.__dict__[name] = new_value
15+
16+
17+
def _inject_stdlib_socket() -> None:
18+
import socket
1619

17-
def enable(
18-
namespace: str | None = None,
19-
truesocket_recording_dir: str | None = None,
20-
) -> None:
21-
from mocket.mocket import Mocket
2220
from mocket.socket import (
2321
MocketSocket,
2422
mock_create_connection,
@@ -27,99 +25,130 @@ def enable(
2725
mock_gethostname,
2826
mock_inet_pton,
2927
mock_socketpair,
30-
mock_urllib3_match_hostname,
3128
)
29+
30+
_replace(socket, "socket", MocketSocket)
31+
_replace(socket, "SocketType", MocketSocket)
32+
_replace(socket, "create_connection", mock_create_connection)
33+
_replace(socket, "getaddrinfo", mock_getaddrinfo)
34+
_replace(socket, "gethostbyname", mock_gethostbyname)
35+
_replace(socket, "gethostname", mock_gethostname)
36+
_replace(socket, "inet_pton", mock_inet_pton)
37+
_replace(socket, "socketpair", mock_socketpair)
38+
39+
40+
def _restore_stdlib_socket() -> None:
41+
import socket
42+
43+
from mocket.socket import (
44+
true_socket_create_connection,
45+
true_socket_getaddrinfo,
46+
true_socket_gethostbyname,
47+
true_socket_gethostname,
48+
true_socket_inet_pton,
49+
true_socket_socket,
50+
true_socket_socket_type,
51+
true_socket_socketpair,
52+
)
53+
54+
_replace(socket, "SocketType", true_socket_socket_type)
55+
_replace(socket, "create_connection", true_socket_create_connection)
56+
_replace(socket, "getaddrinfo", true_socket_getaddrinfo)
57+
_replace(socket, "gethostbyname", true_socket_gethostbyname)
58+
_replace(socket, "gethostname", true_socket_gethostname)
59+
_replace(socket, "inet_pton", true_socket_inet_pton)
60+
_replace(socket, "socket", true_socket_socket)
61+
_replace(socket, "socketpair", true_socket_socketpair)
62+
63+
64+
def _inject_stdlib_ssl() -> None:
65+
import ssl
66+
3267
from mocket.ssl.context import MocketSSLContext
3368

69+
_replace(ssl, "SSLContext", MocketSSLContext)
70+
71+
if python_version() < Version("3.12.0"):
72+
_replace(ssl, "wrap_socket", MocketSSLContext.wrap_socket)
73+
74+
75+
def _restore_stdlib_ssl() -> None:
76+
import ssl
77+
78+
from mocket.ssl.context import true_ssl_ssl_context
79+
80+
_replace(ssl, "SSLContext", true_ssl_ssl_context)
81+
82+
if python_version() < Version("3.12.0"):
83+
from mocket.ssl.context import true_ssl_wrap_socket
84+
85+
_replace(ssl, "wrap_socket", true_ssl_wrap_socket)
86+
87+
88+
def _inject_urllib3() -> None:
89+
import urllib3
90+
91+
from mocket.socket import mock_urllib3_match_hostname
92+
from mocket.ssl.context import MocketSSLContext
93+
94+
_replace(urllib3.util.ssl_, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
95+
_replace(urllib3.util, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
96+
_replace(urllib3.connection, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
97+
_replace(urllib3.connection, "match_hostname", mock_urllib3_match_hostname)
98+
99+
if package_version("urllib3") < Version("2.0.0"):
100+
_replace(urllib3.util.ssl_, "wrap_socket", MocketSSLContext.wrap_socket)
101+
102+
with contextlib.suppress(ImportError):
103+
from urllib3.contrib.pyopenssl import extract_from_urllib3
104+
105+
extract_from_urllib3()
106+
107+
108+
def _restore_urllib3() -> None:
109+
import urllib3
110+
111+
from mocket.socket import true_urllib3_match_hostname
112+
from mocket.ssl.context import true_urllib3_ssl_wrap_socket
113+
114+
_replace(urllib3.connection, "match_hostname", true_urllib3_match_hostname)
115+
_replace(urllib3.util.ssl_, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
116+
_replace(urllib3.util, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
117+
_replace(urllib3.connection, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
118+
119+
if package_version("urllib3") < Version("2.0.0"):
120+
from mocket.ssl.context import true_urllib3_wrap_socket
121+
122+
_replace(urllib3.util.ssl_, "wrap_socket", true_urllib3_wrap_socket)
123+
124+
with contextlib.suppress(ImportError):
125+
from urllib3.contrib.pyopenssl import inject_into_urllib3
126+
127+
inject_into_urllib3()
128+
129+
130+
def enable(
131+
namespace: str | None = None,
132+
truesocket_recording_dir: str | None = None,
133+
) -> None:
134+
_inject_stdlib_socket()
135+
_inject_stdlib_ssl()
136+
_inject_urllib3()
137+
138+
from mocket.mocket import Mocket
139+
34140
Mocket._namespace = namespace
35141
Mocket._truesocket_recording_dir = truesocket_recording_dir
36-
37142
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
38143
# JSON dumps will be saved here
39144
raise AssertionError
40145

41-
socket.socket = socket.__dict__["socket"] = MocketSocket
42-
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
43-
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
44-
socket.create_connection = socket.__dict__["create_connection"] = (
45-
mock_create_connection
46-
)
47-
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
48-
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
49-
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
50-
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
51-
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket
52-
ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext
53-
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
54-
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
55-
MocketSSLContext.wrap_socket
56-
)
57-
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
58-
"ssl_wrap_socket"
59-
] = MocketSSLContext.wrap_socket
60-
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
61-
MocketSSLContext.wrap_socket
62-
)
63-
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
64-
"ssl_wrap_socket"
65-
] = MocketSSLContext.wrap_socket
66-
urllib3.connection.match_hostname = urllib3.connection.__dict__[
67-
"match_hostname"
68-
] = mock_urllib3_match_hostname
69-
if pyopenssl_override: # pragma: no cover
70-
# Take out the pyopenssl version - use the default implementation
71-
extract_from_urllib3()
72-
73146

74147
def disable() -> None:
148+
_restore_stdlib_socket()
149+
_restore_stdlib_ssl()
150+
_restore_urllib3()
151+
75152
from mocket.mocket import Mocket
76-
from mocket.socket import (
77-
true_create_connection,
78-
true_getaddrinfo,
79-
true_gethostbyname,
80-
true_gethostname,
81-
true_inet_pton,
82-
true_socket,
83-
true_socketpair,
84-
true_urllib3_match_hostname,
85-
)
86-
from mocket.ssl.context import (
87-
true_ssl_context,
88-
true_ssl_wrap_socket,
89-
true_urllib3_ssl_wrap_socket,
90-
true_urllib3_wrap_socket,
91-
)
92153

93-
socket.socket = socket.__dict__["socket"] = true_socket
94-
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
95-
socket.SocketType = socket.__dict__["SocketType"] = true_socket
96-
socket.create_connection = socket.__dict__["create_connection"] = (
97-
true_create_connection
98-
)
99-
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
100-
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
101-
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
102-
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
103-
if true_ssl_wrap_socket:
104-
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
105-
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
106-
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
107-
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
108-
true_urllib3_wrap_socket
109-
)
110-
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
111-
"ssl_wrap_socket"
112-
] = true_urllib3_ssl_wrap_socket
113-
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
114-
true_urllib3_ssl_wrap_socket
115-
)
116-
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
117-
"ssl_wrap_socket"
118-
] = true_urllib3_ssl_wrap_socket
119-
urllib3.connection.match_hostname = urllib3.connection.__dict__[
120-
"match_hostname"
121-
] = true_urllib3_match_hostname
122154
Mocket.reset()
123-
if pyopenssl_override: # pragma: no cover
124-
# Put the pyopenssl version back in place
125-
inject_into_urllib3()

mocket/socket.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from types import TracebackType
1212
from typing import Any, Type
1313

14-
import urllib3.connection
14+
import urllib3
1515
from typing_extensions import Self
1616

1717
from mocket.compat import decode_from_bytes, encode_to_bytes
@@ -27,13 +27,14 @@
2727
)
2828
from mocket.utils import hexdump, hexload
2929

30-
true_create_connection = socket.create_connection
31-
true_getaddrinfo = socket.getaddrinfo
32-
true_gethostbyname = socket.gethostbyname
33-
true_gethostname = socket.gethostname
34-
true_inet_pton = socket.inet_pton
35-
true_socket = socket.socket
36-
true_socketpair = socket.socketpair
30+
true_socket_socket = socket.socket
31+
true_socket_socket_type = socket.SocketType
32+
true_socket_create_connection = socket.create_connection
33+
true_socket_gethostname = socket.gethostname
34+
true_socket_gethostbyname = socket.gethostbyname
35+
true_socket_getaddrinfo = socket.getaddrinfo
36+
true_socket_socketpair = socket.socketpair
37+
true_socket_inet_pton = socket.inet_pton
3738
true_urllib3_match_hostname = urllib3.connection.match_hostname
3839

3940

@@ -106,7 +107,7 @@ def __init__(
106107
self._proto = proto
107108

108109
self._kwargs = kwargs
109-
self._true_socket = true_socket(family, type, proto)
110+
self._true_socket = true_socket_socket(family, type, proto)
110111

111112
self._buflen = 65536
112113
self._timeout: float | None = None
@@ -187,7 +188,7 @@ def getblocking(self) -> bool:
187188
return self.gettimeout() is None
188189

189190
def getsockname(self) -> _RetAddress:
190-
return true_gethostbyname(self._address[0]), self._address[1]
191+
return true_socket_gethostbyname(self._address[0]), self._address[1]
191192

192193
def connect(self, address: Address) -> None:
193194
self._address = self._host, self._port = address
@@ -295,7 +296,7 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
295296
# if not available, call the real sendall
296297
except KeyError:
297298
host, port = self._host, self._port
298-
host = true_gethostbyname(host)
299+
host = true_socket_gethostbyname(host)
299300

300301
with contextlib.suppress(OSError, ValueError):
301302
# already connected

mocket/ssl/context.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,24 @@
11
from __future__ import annotations
22

3-
import contextlib
43
import ssl
54
from typing import Any
65

7-
import urllib3.util.ssl_
6+
import urllib3
7+
from packaging.version import Version
88

99
from mocket.socket import MocketSocket
1010
from mocket.ssl.socket import MocketSSLSocket
11+
from mocket.utils import package_version, python_version
1112

12-
true_ssl_context = ssl.SSLContext
13+
true_ssl_ssl_context = ssl.SSLContext
1314

14-
true_ssl_wrap_socket = None
15-
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
16-
true_urllib3_wrap_socket = None
17-
18-
with contextlib.suppress(ImportError):
19-
# from Py3.12 it's only under SSLContext
20-
from ssl import wrap_socket as ssl_wrap_socket
15+
if python_version() < Version("3.12.0"):
16+
true_ssl_wrap_socket = ssl.wrap_socket
2117

22-
true_ssl_wrap_socket = ssl_wrap_socket
23-
24-
with contextlib.suppress(ImportError):
25-
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
18+
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
2619

27-
true_urllib3_wrap_socket = urllib3_wrap_socket
20+
if package_version("urllib3") < Version("2.0.0"):
21+
true_urllib3_wrap_socket = urllib3.util.ssl_.wrap_socket
2822

2923

3024
class _MocketSSLContext:

mocket/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
from __future__ import annotations
22

33
import binascii
4+
import importlib.metadata
5+
import platform
46
import ssl
57
from typing import Callable
68

9+
import packaging.version
10+
711
from mocket.compat import decode_from_bytes, encode_to_bytes
812

9-
# NOTE this is here for backwards-compat to keep old import-paths working
10-
from mocket.io import MocketSocketIO as MocketSocketCore
13+
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2
1114

12-
# NOTE this is here for backwards-compat to keep old import-paths working
13-
from mocket.mode import MocketMode
1415

15-
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2
16+
def python_version() -> packaging.version.Version:
17+
return packaging.version.Version(platform.python_version())
18+
19+
20+
def package_version(package_name: str) -> packaging.version.Version:
21+
version = importlib.metadata.version(package_name)
22+
return packaging.version.parse(version)
1623

1724

1825
def hexdump(binary_string: bytes) -> str:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"decorator>=4.0.0",
3232
"urllib3>=1.25.3",
3333
"h11",
34+
"packaging>=24.2",
3435
]
3536
dynamic = ["version"]
3637

tests/test_mode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mocket import Mocketizer, mocketize
55
from mocket.exceptions import StrictMocketException
66
from mocket.mockhttp import Entry, Response
7-
from mocket.utils import MocketMode
7+
from mocket.mode import MocketMode
88

99

1010
@mocketize(strict_mode=True)

0 commit comments

Comments
 (0)