Skip to content

Commit 8b30c00

Browse files
committed
refactor: make injection code more readable and make backwards-compat explicit
1 parent 0da2722 commit 8b30c00

File tree

6 files changed

+146
-114
lines changed

6 files changed

+146
-114
lines changed

mocket/inject.py

Lines changed: 122 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import os
4-
import socket
5-
import ssl
5+
from types import ModuleType
6+
from typing import Any
67

7-
import urllib3
8+
from packaging.version import Version
89

9-
try: # pragma: no cover
10-
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
10+
from mocket.utils import package_version, python_version
1111

12-
pyopenssl_override = True
13-
except ImportError:
14-
pyopenssl_override = False
1512

13+
def _replace(module: ModuleType, name: str, new_value: Any) -> None:
14+
module.__dict__[name] = new_value
15+
16+
17+
def _inject_stdlib_socket() -> None:
18+
import socket
1619

17-
def enable(
18-
namespace: str | None = None,
19-
truesocket_recording_dir: str | None = None,
20-
) -> None:
21-
from mocket.mocket import Mocket
2220
from mocket.socket import (
2321
MocketSocket,
2422
mock_create_connection,
@@ -27,99 +25,130 @@ def enable(
2725
mock_gethostname,
2826
mock_inet_pton,
2927
mock_socketpair,
30-
mock_urllib3_match_hostname,
3128
)
32-
from mocket.ssl.context import MocketSSLContext
3329

34-
Mocket._namespace = namespace
35-
Mocket._truesocket_recording_dir = truesocket_recording_dir
30+
_replace(socket, "socket", MocketSocket)
31+
_replace(socket, "create_connection", mock_create_connection)
32+
_replace(socket, "getaddrinfo", mock_getaddrinfo)
33+
_replace(socket, "gethostbyname", mock_gethostbyname)
34+
_replace(socket, "gethostname", mock_gethostname)
35+
_replace(socket, "inet_pton", mock_inet_pton)
36+
_replace(socket, "SocketType", MocketSocket)
37+
_replace(socket, "socketpair", mock_socketpair)
3638

37-
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
38-
# JSON dumps will be saved here
39-
raise AssertionError
4039

41-
socket.socket = socket.__dict__["socket"] = MocketSocket
42-
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
43-
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
44-
socket.create_connection = socket.__dict__["create_connection"] = (
45-
mock_create_connection
46-
)
47-
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
48-
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
49-
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
50-
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
51-
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket
52-
ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext
53-
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
54-
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
55-
MocketSSLContext.wrap_socket
56-
)
57-
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
58-
"ssl_wrap_socket"
59-
] = MocketSSLContext.wrap_socket
60-
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
61-
MocketSSLContext.wrap_socket
62-
)
63-
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
64-
"ssl_wrap_socket"
65-
] = MocketSSLContext.wrap_socket
66-
urllib3.connection.match_hostname = urllib3.connection.__dict__[
67-
"match_hostname"
68-
] = mock_urllib3_match_hostname
69-
if pyopenssl_override: # pragma: no cover
70-
# Take out the pyopenssl version - use the default implementation
71-
extract_from_urllib3()
40+
def _restore_stdlib_socket() -> None:
41+
import socket
7242

73-
74-
def disable() -> None:
75-
from mocket.mocket import Mocket
7643
from mocket.socket import (
7744
true_create_connection,
7845
true_getaddrinfo,
7946
true_gethostbyname,
8047
true_gethostname,
8148
true_inet_pton,
8249
true_socket,
50+
true_socket_type,
8351
true_socketpair,
84-
true_urllib3_match_hostname,
85-
)
86-
from mocket.ssl.context import (
87-
true_ssl_context,
88-
true_ssl_wrap_socket,
89-
true_urllib3_ssl_wrap_socket,
90-
true_urllib3_wrap_socket,
9152
)
9253

93-
socket.socket = socket.__dict__["socket"] = true_socket
94-
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
95-
socket.SocketType = socket.__dict__["SocketType"] = true_socket
96-
socket.create_connection = socket.__dict__["create_connection"] = (
97-
true_create_connection
98-
)
99-
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
100-
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
101-
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
102-
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
103-
if true_ssl_wrap_socket:
104-
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
105-
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
106-
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
107-
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
108-
true_urllib3_wrap_socket
109-
)
110-
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
111-
"ssl_wrap_socket"
112-
] = true_urllib3_ssl_wrap_socket
113-
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
114-
true_urllib3_ssl_wrap_socket
115-
)
116-
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
117-
"ssl_wrap_socket"
118-
] = true_urllib3_ssl_wrap_socket
119-
urllib3.connection.match_hostname = urllib3.connection.__dict__[
120-
"match_hostname"
121-
] = true_urllib3_match_hostname
122-
Mocket.reset()
123-
if pyopenssl_override: # pragma: no cover
124-
# Put the pyopenssl version back in place
54+
_replace(socket, "create_connection", true_create_connection)
55+
_replace(socket, "getaddrinfo", true_getaddrinfo)
56+
_replace(socket, "gethostbyname", true_gethostbyname)
57+
_replace(socket, "gethostname", true_gethostname)
58+
_replace(socket, "inet_pton", true_inet_pton)
59+
_replace(socket, "socket", true_socket)
60+
_replace(socket, "SocketType", true_socket_type)
61+
_replace(socket, "socketpair", true_socketpair)
62+
63+
64+
def _inject_stdlib_ssl() -> None:
65+
import ssl
66+
67+
from mocket.ssl.context import MocketSSLContext
68+
69+
_replace(ssl, "SSLContext", MocketSSLContext)
70+
71+
if python_version() < Version("3.12.0"):
72+
_replace(ssl, "wrap_socket", MocketSSLContext.wrap_socket)
73+
74+
75+
def _restore_stdlib_ssl() -> None:
76+
import ssl
77+
78+
from mocket.ssl.context import true_ssl_context
79+
80+
_replace(ssl, "SSLContext", true_ssl_context)
81+
82+
if python_version() < Version("3.12.0"):
83+
from mocket.ssl.context import true_wrap_socket
84+
85+
_replace(ssl, "wrap_socket", true_wrap_socket)
86+
87+
88+
def _inject_urllib3() -> None:
89+
import urllib3
90+
91+
from mocket.socket import mock_urllib3_match_hostname
92+
from mocket.ssl.context import MocketSSLContext
93+
94+
_replace(urllib3.util.ssl_, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
95+
_replace(urllib3.util, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
96+
_replace(urllib3.connection, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
97+
_replace(urllib3.connection, "match_hostname", mock_urllib3_match_hostname)
98+
99+
if package_version("urllib3") < Version("2.0.0"):
100+
_replace(urllib3.util.ssl_, "wrap_socket", MocketSSLContext.wrap_socket)
101+
102+
with contextlib.suppress(ImportError):
103+
from urllib3.contrib.pyopenssl import extract_from_urllib3
104+
105+
extract_from_urllib3()
106+
107+
108+
def _restore_urllib3() -> None:
109+
import urllib3
110+
111+
from mocket.socket import true_urllib3_match_hostname
112+
from mocket.ssl.context import true_urllib3_ssl_wrap_socket
113+
114+
_replace(urllib3.connection, "match_hostname", true_urllib3_match_hostname)
115+
_replace(urllib3.util.ssl_, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
116+
_replace(urllib3.util, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
117+
_replace(urllib3.connection, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
118+
119+
if package_version("urllib3") < Version("2.0.0"):
120+
from mocket.ssl.context import true_urllib3_wrap_socket
121+
122+
_replace(urllib3.util.ssl_, "wrap_socket", true_urllib3_wrap_socket)
123+
124+
with contextlib.suppress(ImportError):
125+
from urllib3.contrib.pyopenssl import inject_into_urllib3
126+
125127
inject_into_urllib3()
128+
129+
130+
def enable(
131+
namespace: str | None = None,
132+
truesocket_recording_dir: str | None = None,
133+
) -> None:
134+
_inject_stdlib_socket()
135+
_inject_stdlib_ssl()
136+
_inject_urllib3()
137+
138+
from mocket.mocket import Mocket
139+
140+
Mocket._namespace = namespace
141+
Mocket._truesocket_recording_dir = truesocket_recording_dir
142+
if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
143+
# JSON dumps will be saved here
144+
raise AssertionError
145+
146+
147+
def disable() -> None:
148+
_restore_stdlib_socket()
149+
_restore_stdlib_ssl()
150+
_restore_urllib3()
151+
152+
from mocket.mocket import Mocket
153+
154+
Mocket.reset()

mocket/socket.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from types import TracebackType
1212
from typing import Any, Type
1313

14-
import urllib3.connection
14+
import urllib3
1515
from typing_extensions import Self
1616

1717
from mocket.compat import decode_from_bytes, encode_to_bytes
@@ -33,6 +33,7 @@
3333
true_gethostname = socket.gethostname
3434
true_inet_pton = socket.inet_pton
3535
true_socket = socket.socket
36+
true_socket_type = socket.SocketType
3637
true_socketpair = socket.socketpair
3738
true_urllib3_match_hostname = urllib3.connection.match_hostname
3839

mocket/ssl/context.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,24 @@
11
from __future__ import annotations
22

3-
import contextlib
43
import ssl
54
from typing import Any
65

7-
import urllib3.util.ssl_
6+
import urllib3
7+
from packaging.version import Version
88

99
from mocket.socket import MocketSocket
1010
from mocket.ssl.socket import MocketSSLSocket
11+
from mocket.utils import package_version, python_version
1112

1213
true_ssl_context = ssl.SSLContext
1314

14-
true_ssl_wrap_socket = None
15-
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
16-
true_urllib3_wrap_socket = None
17-
18-
with contextlib.suppress(ImportError):
19-
# from Py3.12 it's only under SSLContext
20-
from ssl import wrap_socket as ssl_wrap_socket
15+
if python_version() < Version("3.12.0"):
16+
true_wrap_socket = ssl.wrap_socket
2117

22-
true_ssl_wrap_socket = ssl_wrap_socket
23-
24-
with contextlib.suppress(ImportError):
25-
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
18+
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
2619

27-
true_urllib3_wrap_socket = urllib3_wrap_socket
20+
if package_version("urllib3") < Version("2.0.0"):
21+
true_urllib3_wrap_socket = urllib3.util.ssl_.wrap_socket
2822

2923

3024
class _MocketSSLContext:

mocket/utils.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
from __future__ import annotations
22

33
import binascii
4+
import importlib.metadata
5+
import platform
46
import ssl
57
from typing import Callable
68

9+
import packaging.version
10+
711
from mocket.compat import decode_from_bytes, encode_to_bytes
812

9-
# NOTE this is here for backwards-compat to keep old import-paths working
10-
from mocket.io import MocketSocketIO as MocketSocketCore
13+
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2
1114

12-
# NOTE this is here for backwards-compat to keep old import-paths working
13-
from mocket.mode import MocketMode
1415

15-
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2
16+
def python_version() -> packaging.version.Version:
17+
return packaging.version.Version(platform.python_version())
18+
19+
20+
def package_version(package_name: str) -> packaging.version.Version:
21+
version = importlib.metadata.version(package_name)
22+
return packaging.version.parse(version)
1623

1724

1825
def hexdump(binary_string: bytes) -> str:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies = [
3131
"decorator>=4.0.0",
3232
"urllib3>=1.25.3",
3333
"h11",
34+
"packaging>=24.2",
3435
]
3536
dynamic = ["version"]
3637

tests/test_mode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from mocket import Mocketizer, mocketize
55
from mocket.exceptions import StrictMocketException
66
from mocket.mockhttp import Entry, Response
7-
from mocket.utils import MocketMode
7+
from mocket.mode import MocketMode
88

99

1010
@mocketize(strict_mode=True)

0 commit comments

Comments
 (0)