11from __future__ import annotations
22
3+ import contextlib
34import os
45import socket
56import ssl
7+ from types import ModuleType
8+ from typing import Any
69
710import urllib3
811
9- try : # pragma: no cover
10- from urllib3 .contrib .pyopenssl import extract_from_urllib3 , inject_into_urllib3
12+ _patches_restore : dict [tuple [ModuleType , str ], Any ] = {}
1113
12- pyopenssl_override = True
13- except ImportError :
14- pyopenssl_override = False
14+
15+ def _patch (module : ModuleType , name : str , patched_value : Any ) -> None :
16+ with contextlib .suppress (KeyError ):
17+ original_value , module .__dict__ [name ] = module .__dict__ [name ], patched_value
18+ _patches_restore [(module , name )] = original_value
19+
20+
21+ def _restore (module : ModuleType , name : str ) -> None :
22+ if original_value := _patches_restore .pop ((module , name )):
23+ module .__dict__ [name ] = original_value
1524
1625
1726def enable (
1827 namespace : str | None = None ,
1928 truesocket_recording_dir : str | None = None ,
2029) -> None :
21- from mocket .mocket import Mocket
2230 from mocket .socket import (
2331 MocketSocket ,
2432 mock_create_connection ,
@@ -27,99 +35,62 @@ def enable(
2735 mock_gethostname ,
2836 mock_inet_pton ,
2937 mock_socketpair ,
30- mock_urllib3_match_hostname ,
3138 )
3239 from mocket .ssl .context import MocketSSLContext
40+ from mocket .urllib3 import (
41+ mock_match_hostname as mock_urllib3_match_hostname ,
42+ )
43+ from mocket .urllib3 import (
44+ mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket ,
45+ )
46+
47+ patches = {
48+ # stdlib: socket
49+ (socket , "socket" ): MocketSocket ,
50+ (socket , "create_connection" ): mock_create_connection ,
51+ (socket , "getaddrinfo" ): mock_getaddrinfo ,
52+ (socket , "gethostbyname" ): mock_gethostbyname ,
53+ (socket , "gethostname" ): mock_gethostname ,
54+ (socket , "inet_pton" ): mock_inet_pton ,
55+ (socket , "SocketType" ): MocketSocket ,
56+ (socket , "socketpair" ): mock_socketpair ,
57+ # stdlib: ssl
58+ (ssl , "SSLContext" ): MocketSSLContext ,
59+ (ssl , "wrap_socket" ): MocketSSLContext .wrap_socket , # python < 3.12.0
60+ # urllib3
61+ (urllib3 .connection , "match_hostname" ): mock_urllib3_match_hostname ,
62+ (urllib3 .connection , "ssl_wrap_socket" ): mock_urllib3_ssl_wrap_socket ,
63+ (urllib3 .util , "ssl_wrap_socket" ): mock_urllib3_ssl_wrap_socket ,
64+ (urllib3 .util .ssl_ , "ssl_wrap_socket" ): mock_urllib3_ssl_wrap_socket ,
65+ (urllib3 .util .ssl_ , "wrap_socket" ): mock_urllib3_ssl_wrap_socket , # urllib3 < 2
66+ }
67+
68+ for (module , name ), new_value in patches .items ():
69+ _patch (module , name , new_value )
70+
71+ with contextlib .suppress (ImportError ):
72+ from urllib3 .contrib .pyopenssl import extract_from_urllib3
73+
74+ extract_from_urllib3 ()
75+
76+ from mocket .mocket import Mocket
3377
3478 Mocket ._namespace = namespace
3579 Mocket ._truesocket_recording_dir = truesocket_recording_dir
36-
3780 if truesocket_recording_dir and not os .path .isdir (truesocket_recording_dir ):
3881 # JSON dumps will be saved here
3982 raise AssertionError
4083
41- socket .socket = socket .__dict__ ["socket" ] = MocketSocket
42- socket ._socketobject = socket .__dict__ ["_socketobject" ] = MocketSocket
43- socket .SocketType = socket .__dict__ ["SocketType" ] = MocketSocket
44- socket .create_connection = socket .__dict__ ["create_connection" ] = (
45- mock_create_connection
46- )
47- socket .gethostname = socket .__dict__ ["gethostname" ] = mock_gethostname
48- socket .gethostbyname = socket .__dict__ ["gethostbyname" ] = mock_gethostbyname
49- socket .getaddrinfo = socket .__dict__ ["getaddrinfo" ] = mock_getaddrinfo
50- socket .socketpair = socket .__dict__ ["socketpair" ] = mock_socketpair
51- ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = MocketSSLContext .wrap_socket
52- ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = MocketSSLContext
53- socket .inet_pton = socket .__dict__ ["inet_pton" ] = mock_inet_pton
54- urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ ["wrap_socket" ] = (
55- MocketSSLContext .wrap_socket
56- )
57- urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
58- "ssl_wrap_socket"
59- ] = MocketSSLContext .wrap_socket
60- urllib3 .util .ssl_wrap_socket = urllib3 .util .__dict__ ["ssl_wrap_socket" ] = (
61- MocketSSLContext .wrap_socket
62- )
63- urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
64- "ssl_wrap_socket"
65- ] = MocketSSLContext .wrap_socket
66- urllib3 .connection .match_hostname = urllib3 .connection .__dict__ [
67- "match_hostname"
68- ] = mock_urllib3_match_hostname
69- if pyopenssl_override : # pragma: no cover
70- # Take out the pyopenssl version - use the default implementation
71- extract_from_urllib3 ()
72-
7384
7485def disable () -> None :
86+ for module , name in list (_patches_restore .keys ()):
87+ _restore (module , name )
88+
89+ with contextlib .suppress (ImportError ):
90+ from urllib3 .contrib .pyopenssl import inject_into_urllib3
91+
92+ inject_into_urllib3 ()
93+
7594 from mocket .mocket import Mocket
76- from mocket .socket import (
77- true_create_connection ,
78- true_getaddrinfo ,
79- true_gethostbyname ,
80- true_gethostname ,
81- true_inet_pton ,
82- true_socket ,
83- true_socketpair ,
84- true_urllib3_match_hostname ,
85- )
86- from mocket .ssl .context import (
87- true_ssl_context ,
88- true_ssl_wrap_socket ,
89- true_urllib3_ssl_wrap_socket ,
90- true_urllib3_wrap_socket ,
91- )
9295
93- socket .socket = socket .__dict__ ["socket" ] = true_socket
94- socket ._socketobject = socket .__dict__ ["_socketobject" ] = true_socket
95- socket .SocketType = socket .__dict__ ["SocketType" ] = true_socket
96- socket .create_connection = socket .__dict__ ["create_connection" ] = (
97- true_create_connection
98- )
99- socket .gethostname = socket .__dict__ ["gethostname" ] = true_gethostname
100- socket .gethostbyname = socket .__dict__ ["gethostbyname" ] = true_gethostbyname
101- socket .getaddrinfo = socket .__dict__ ["getaddrinfo" ] = true_getaddrinfo
102- socket .socketpair = socket .__dict__ ["socketpair" ] = true_socketpair
103- if true_ssl_wrap_socket :
104- ssl .wrap_socket = ssl .__dict__ ["wrap_socket" ] = true_ssl_wrap_socket
105- ssl .SSLContext = ssl .__dict__ ["SSLContext" ] = true_ssl_context
106- socket .inet_pton = socket .__dict__ ["inet_pton" ] = true_inet_pton
107- urllib3 .util .ssl_ .wrap_socket = urllib3 .util .ssl_ .__dict__ ["wrap_socket" ] = (
108- true_urllib3_wrap_socket
109- )
110- urllib3 .util .ssl_ .ssl_wrap_socket = urllib3 .util .ssl_ .__dict__ [
111- "ssl_wrap_socket"
112- ] = true_urllib3_ssl_wrap_socket
113- urllib3 .util .ssl_wrap_socket = urllib3 .util .__dict__ ["ssl_wrap_socket" ] = (
114- true_urllib3_ssl_wrap_socket
115- )
116- urllib3 .connection .ssl_wrap_socket = urllib3 .connection .__dict__ [
117- "ssl_wrap_socket"
118- ] = true_urllib3_ssl_wrap_socket
119- urllib3 .connection .match_hostname = urllib3 .connection .__dict__ [
120- "match_hostname"
121- ] = true_urllib3_match_hostname
12296 Mocket .reset ()
123- if pyopenssl_override : # pragma: no cover
124- # Put the pyopenssl version back in place
125- inject_into_urllib3 ()
0 commit comments