11from __future__ import annotations
22
3+ import contextlib
34import os
4- import socket
5- import ssl
5+ from types import ModuleType
6+ from typing import Any
67
7- import urllib3
8+ from packaging . version import Version
89
9- try : # pragma: no cover
10- from urllib3 .contrib .pyopenssl import extract_from_urllib3 , inject_into_urllib3
10+ from mocket .utils import package_version , python_version
1111
12- pyopenssl_override = True
13- except ImportError :
14- pyopenssl_override = False
1512
13+ def _replace (module : ModuleType , name : str , new_value : Any ) -> None :
14+ module .__dict__ [name ] = new_value
15+
16+
17+ def _inject_stdlib_socket () -> None :
18+ import socket
1619
17- def enable (
18- namespace : str | None = None ,
19- truesocket_recording_dir : str | None = None ,
20- ) -> None :
21- from mocket .mocket import Mocket
2220 from mocket .socket import (
2321 MocketSocket ,
2422 mock_create_connection ,
@@ -27,99 +25,130 @@ def enable(
2725 mock_gethostname ,
2826 mock_inet_pton ,
2927 mock_socketpair ,
30- mock_urllib3_match_hostname ,
3128 )
29+
30+ _replace (socket , "socket" , MocketSocket )
31+ _replace (socket , "SocketType" , MocketSocket )
32+ _replace (socket , "create_connection" , mock_create_connection )
33+ _replace (socket , "getaddrinfo" , mock_getaddrinfo )
34+ _replace (socket , "gethostbyname" , mock_gethostbyname )
35+ _replace (socket , "gethostname" , mock_gethostname )
36+ _replace (socket , "inet_pton" , mock_inet_pton )
37+ _replace (socket , "socketpair" , mock_socketpair )
38+
39+
40+ def _restore_stdlib_socket () -> None :
41+ import socket
42+
43+ from mocket .socket import (
44+ true_socket_create_connection ,
45+ true_socket_getaddrinfo ,
46+ true_socket_gethostbyname ,
47+ true_socket_gethostname ,
48+ true_socket_inet_pton ,
49+ true_socket_socket ,
50+ true_socket_socket_type ,
51+ true_socket_socketpair ,
52+ )
53+
54+ _replace (socket , "SocketType" , true_socket_socket_type )
55+ _replace (socket , "create_connection" , true_socket_create_connection )
56+ _replace (socket , "getaddrinfo" , true_socket_getaddrinfo )
57+ _replace (socket , "gethostbyname" , true_socket_gethostbyname )
58+ _replace (socket , "gethostname" , true_socket_gethostname )
59+ _replace (socket , "inet_pton" , true_socket_inet_pton )
60+ _replace (socket , "socket" , true_socket_socket )
61+ _replace (socket , "socketpair" , true_socket_socketpair )
62+
63+
64+ def _inject_stdlib_ssl () -> None :
65+ import ssl
66+
3267 from mocket .ssl .context import MocketSSLContext
3368
69+ _replace (ssl , "SSLContext" , MocketSSLContext )
70+
71+ if python_version () < Version ("3.12.0" ):
72+ _replace (ssl , "wrap_socket" , MocketSSLContext .wrap_socket )
73+
74+
75+ def _restore_stdlib_ssl () -> None :
76+ import ssl
77+
78+ from mocket .ssl .context import true_ssl_ssl_context
79+
80+ _replace (ssl , "SSLContext" , true_ssl_ssl_context )
81+
82+ if python_version () < Version ("3.12.0" ):
83+ from mocket .ssl .context import true_ssl_wrap_socket
84+
85+ _replace (ssl , "wrap_socket" , true_ssl_wrap_socket )
86+
87+
88+ def _inject_urllib3 () -> None :
89+ import urllib3
90+
91+ from mocket .socket import mock_urllib3_match_hostname
92+ from mocket .ssl .context import MocketSSLContext
93+
94+ _replace (urllib3 .util .ssl_ , "ssl_wrap_socket" , MocketSSLContext .wrap_socket )
95+ _replace (urllib3 .util , "ssl_wrap_socket" , MocketSSLContext .wrap_socket )
96+ _replace (urllib3 .connection , "ssl_wrap_socket" , MocketSSLContext .wrap_socket )
97+ _replace (urllib3 .connection , "match_hostname" , mock_urllib3_match_hostname )
98+
99+ if package_version ("urllib3" ) < Version ("2.0.0" ):
100+ _replace (urllib3 .util .ssl_ , "wrap_socket" , MocketSSLContext .wrap_socket )
101+
102+ with contextlib .suppress (ImportError ):
103+ from urllib3 .contrib .pyopenssl import extract_from_urllib3
104+
105+ extract_from_urllib3 ()
106+
107+
108+ def _restore_urllib3 () -> None :
109+ import urllib3
110+
111+ from mocket .socket import true_urllib3_match_hostname
112+ from mocket .ssl .context import true_urllib3_ssl_wrap_socket
113+
114+ _replace (urllib3 .connection , "match_hostname" , true_urllib3_match_hostname )
115+ _replace (urllib3 .util .ssl_ , "ssl_wrap_socket" , true_urllib3_ssl_wrap_socket )
116+ _replace (urllib3 .util , "ssl_wrap_socket" , true_urllib3_ssl_wrap_socket )
117+ _replace (urllib3 .connection , "ssl_wrap_socket" , true_urllib3_ssl_wrap_socket )
118+
119+ if package_version ("urllib3" ) < Version ("2.0.0" ):
120+ from mocket .ssl .context import true_urllib3_wrap_socket
121+
122+ _replace (urllib3 .util .ssl_ , "wrap_socket" , true_urllib3_wrap_socket )
123+
124+ with contextlib .suppress (ImportError ):
125+ from urllib3 .contrib .pyopenssl import inject_into_urllib3
126+
127+ inject_into_urllib3 ()
128+
129+
130+ def enable (
131+ namespace : str | None = None ,
132+ truesocket_recording_dir : str | None = None ,
133+ ) -> None :
134+ _inject_stdlib_socket ()
135+ _inject_stdlib_ssl ()
136+ _inject_urllib3 ()
137+
138+ from mocket .mocket import Mocket
139+
34140 Mocket ._namespace = namespace
35141 Mocket ._truesocket_recording_dir = truesocket_recording_dir
36-
37142 if truesocket_recording_dir and not os .path .isdir (truesocket_recording_dir ):
38143 # JSON dumps will be saved here
39144 raise AssertionError
40145
41- socket .socket = socket .__dict__ ["socket" ] = MocketSocket
42- socket ._socketobject = socket .__dict__ ["_socketobject" ] = MocketSocket
43- socket .SocketType = socket .__dict__ ["SocketType" ] = MocketSocket
44- socket .create_connection = socket .__dict__ ["create_connection" ] = (
45- mock_create_connection
46- )
47- socket .gethostname = socket .__dict__ ["gethostname" ] = mock_gethostname
48- socket .gethostbyname = socket .__dict__ ["gethostbyname" ] = mock_gethostbyname
49- socket .getaddrinfo = socket .__dict__ ["getaddrinfo" ] = mock_getaddrinfo
50- socket .socketpair = socket .__dict__ ["socketpair" ] = mock_socketpair
51- ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = MocketSSLContext .wrap_socket
52- ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = MocketSSLContext
53- socket .inet_pton = socket .__dict__ ["inet_pton" ] = mock_inet_pton
54- urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ ["wrap_socket" ] = (
55- MocketSSLContext .wrap_socket
56- )
57- urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
58- "ssl_wrap_socket"
59- ] = MocketSSLContext .wrap_socket
60- urllib3 .util .ssl_wrap_socket = urllib3 .util .__dict__ ["ssl_wrap_socket" ] = (
61- MocketSSLContext .wrap_socket
62- )
63- urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
64- "ssl_wrap_socket"
65- ] = MocketSSLContext .wrap_socket
66- urllib3 .connection .match_hostname = urllib3 .connection .__dict__ [
67- "match_hostname"
68- ] = mock_urllib3_match_hostname
69- if pyopenssl_override : # pragma: no cover
70- # Take out the pyopenssl version - use the default implementation
71- extract_from_urllib3 ()
72-
73146
74147def disable () -> None :
148+ _restore_stdlib_socket ()
149+ _restore_stdlib_ssl ()
150+ _restore_urllib3 ()
151+
75152 from mocket .mocket import Mocket
76- from mocket .socket import (
77- true_create_connection ,
78- true_getaddrinfo ,
79- true_gethostbyname ,
80- true_gethostname ,
81- true_inet_pton ,
82- true_socket ,
83- true_socketpair ,
84- true_urllib3_match_hostname ,
85- )
86- from mocket .ssl .context import (
87- true_ssl_context ,
88- true_ssl_wrap_socket ,
89- true_urllib3_ssl_wrap_socket ,
90- true_urllib3_wrap_socket ,
91- )
92153
93- socket .socket = socket .__dict__ ["socket" ] = true_socket
94- socket ._socketobject = socket .__dict__ ["_socketobject" ] = true_socket
95- socket .SocketType = socket .__dict__ ["SocketType" ] = true_socket
96- socket .create_connection = socket .__dict__ ["create_connection" ] = (
97- true_create_connection
98- )
99- socket .gethostname = socket .__dict__ ["gethostname" ] = true_gethostname
100- socket .gethostbyname = socket .__dict__ ["gethostbyname" ] = true_gethostbyname
101- socket .getaddrinfo = socket .__dict__ ["getaddrinfo" ] = true_getaddrinfo
102- socket .socketpair = socket .__dict__ ["socketpair" ] = true_socketpair
103- if true_ssl_wrap_socket :
104- ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = true_ssl_wrap_socket
105- ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = true_ssl_context
106- socket .inet_pton = socket .__dict__ ["inet_pton" ] = true_inet_pton
107- urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ ["wrap_socket" ] = (
108- true_urllib3_wrap_socket
109- )
110- urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
111- "ssl_wrap_socket"
112- ] = true_urllib3_ssl_wrap_socket
113- urllib3 .util .ssl_wrap_socket = urllib3 .util .__dict__ ["ssl_wrap_socket" ] = (
114- true_urllib3_ssl_wrap_socket
115- )
116- urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
117- "ssl_wrap_socket"
118- ] = true_urllib3_ssl_wrap_socket
119- urllib3 .connection .match_hostname = urllib3 .connection .__dict__ [
120- "match_hostname"
121- ] = true_urllib3_match_hostname
122154 Mocket .reset ()
123- if pyopenssl_override : # pragma: no cover
124- # Put the pyopenssl version back in place
125- inject_into_urllib3 ()
0 commit comments