Skip to content

Commit 95e00e0

Browse files
committed
refactor: Split ssl-specific parts of MocketSocket into MocketSSLSocket
1 parent e75a257 commit 95e00e0

File tree

5 files changed

+121
-90
lines changed

5 files changed

+121
-90
lines changed

mocket/inject.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,16 @@ def disable() -> None:
8181
true_inet_pton,
8282
true_socket,
8383
true_socketpair,
84-
true_ssl_wrap_socket,
8584
true_urllib3_match_hostname,
86-
true_urllib3_ssl_wrap_socket,
87-
true_urllib3_wrap_socket,
8885
)
8986
from mocket.ssl.context import (
9087
true_ssl_context,
9188
)
89+
from mocket.ssl.socket import (
90+
true_ssl_wrap_socket,
91+
true_urllib3_ssl_wrap_socket,
92+
true_urllib3_wrap_socket,
93+
)
9294

9395
socket.socket = socket.__dict__["socket"] = true_socket
9496
socket._socketobject = socket.__dict__["_socketobject"] = true_socket

mocket/socket.py

Lines changed: 7 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,11 @@
77
import os
88
import select
99
import socket
10-
import ssl
11-
from datetime import datetime, timedelta
1210
from json.decoder import JSONDecodeError
1311
from types import TracebackType
1412
from typing import Any, Type
1513

1614
import urllib3.connection
17-
import urllib3.util.ssl_
1815
from typing_extensions import Self
1916

2017
from mocket.compat import decode_from_bytes, encode_to_bytes
@@ -26,7 +23,6 @@
2623
Address,
2724
ReadableBuffer,
2825
WriteableBuffer,
29-
_PeerCertRetDictType,
3026
_RetAddress,
3127
)
3228
from mocket.utils import hexdump, hexload
@@ -38,22 +34,7 @@
3834
true_inet_pton = socket.inet_pton
3935
true_socket = socket.socket
4036
true_socketpair = socket.socketpair
41-
true_ssl_wrap_socket = None
42-
4337
true_urllib3_match_hostname = urllib3.connection.match_hostname
44-
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
45-
true_urllib3_wrap_socket = None
46-
47-
with contextlib.suppress(ImportError):
48-
# from Py3.12 it's only under SSLContext
49-
from ssl import wrap_socket as ssl_wrap_socket
50-
51-
true_ssl_wrap_socket = ssl_wrap_socket
52-
53-
with contextlib.suppress(ImportError):
54-
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
55-
56-
true_urllib3_wrap_socket = urllib3_wrap_socket
5738

5839

5940
xxh32 = None
@@ -112,9 +93,6 @@ def _hash_request(h, req):
11293

11394

11495
class MocketSocket:
115-
cipher = lambda s: ("ADH", "AES256", "SHA")
116-
compression = lambda s: ssl.OP_NO_COMPRESSION
117-
11896
def __init__(
11997
self,
12098
family: socket.AddressFamily | int = socket.AF_INET,
@@ -129,15 +107,10 @@ def __init__(
129107

130108
self._kwargs = kwargs
131109
self._true_socket = true_socket(family, type, proto)
132-
self._truesocket_recording_dir = None
133110

134111
self._buflen = 65536
135112
self._timeout: float | None = None
136113

137-
self._secure_socket = False
138-
self._did_handshake = False
139-
self._sent_non_empty_bytes = False
140-
141114
self._host = None
142115
self._port = None
143116
self._address = None
@@ -204,9 +177,6 @@ def settimeout(self, timeout: float | None) -> None:
204177
def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
205178
return socket.SOCK_STREAM
206179

207-
def do_handshake(self) -> None:
208-
self._did_handshake = True
209-
210180
def getpeername(self) -> _RetAddress:
211181
return self._address
212182

@@ -219,32 +189,6 @@ def getblocking(self) -> bool:
219189
def getsockname(self) -> _RetAddress:
220190
return true_gethostbyname(self._address[0]), self._address[1]
221191

222-
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
223-
if not (self._host and self._port):
224-
self._address = self._host, self._port = Mocket._address
225-
226-
now = datetime.now()
227-
shift = now + timedelta(days=30 * 12)
228-
return {
229-
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
230-
"subjectAltName": (
231-
("DNS", f"*.{self._host}"),
232-
("DNS", self._host),
233-
("DNS", "*"),
234-
),
235-
"subject": (
236-
(("organizationName", f"*.{self._host}"),),
237-
(("organizationalUnitName", "Domain Control Validated"),),
238-
(("commonName", f"*.{self._host}"),),
239-
),
240-
}
241-
242-
def unwrap(self) -> MocketSocket:
243-
return self
244-
245-
def write(self, data: bytes) -> int | None:
246-
return self.send(encode_to_bytes(data))
247-
248192
def connect(self, address: Address) -> None:
249193
self._address = self._host, self._port = address
250194
Mocket._address = address
@@ -271,33 +215,26 @@ def sendall(self, data, entry=None, *args, **kwargs):
271215
self.io.truncate()
272216
self.io.seek(0)
273217

274-
def read(self, buffersize: int | None = None) -> bytes:
275-
rv = self.io.read(buffersize)
276-
if rv:
277-
self._sent_non_empty_bytes = True
278-
if self._did_handshake and not self._sent_non_empty_bytes:
279-
raise ssl.SSLWantReadError("The operation did not complete (read)")
280-
return rv
281-
282218
def recv_into(
283219
self,
284220
buffer: WriteableBuffer,
285221
buffersize: int | None = None,
286222
flags: int | None = None,
287223
) -> int:
224+
data = self.recv(buffersize)
225+
288226
if hasattr(buffer, "write"):
289-
return buffer.write(self.read(buffersize))
227+
return buffer.write(data)
228+
290229
# buffer is a memoryview
291-
data = self.read(buffersize)
292-
if data:
293-
buffer[: len(data)] = data
230+
buffer[: len(data)] = data
294231
return len(data)
295232

296233
def recv(self, buffersize: int, flags: int | None = None) -> bytes:
297234
r_fd, _ = Mocket.get_pair((self._host, self._port))
298235
if r_fd:
299236
return os.read(r_fd, buffersize)
300-
data = self.read(buffersize)
237+
data = self.io.read(buffersize)
301238
if data:
302239
return data
303240
# used by Redis mock
@@ -357,12 +294,6 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
357294
host, port = self._host, self._port
358295
host = true_gethostbyname(host)
359296

360-
if isinstance(self._true_socket, true_socket) and self._secure_socket:
361-
self._true_socket = true_urllib3_ssl_wrap_socket(
362-
self._true_socket,
363-
**self._kwargs,
364-
)
365-
366297
with contextlib.suppress(OSError, ValueError):
367298
# already connected
368299
self._true_socket.connect((host, port))
@@ -400,6 +331,7 @@ def send(
400331
*args: Any,
401332
**kwargs: Any,
402333
) -> int: # pragma: no cover
334+
data = encode_to_bytes(data)
403335
entry = self.get_entry(data)
404336
if not entry or (entry and self._entry != entry):
405337
kwargs["entry"] = entry

mocket/ssl/context.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any
66

77
from mocket.socket import MocketSocket
8+
from mocket.ssl.socket import MocketSSLSocket
89
from mocket.types import ReadableBuffer, StrOrBytesPath
910

1011
true_ssl_context = ssl.SSLContext
@@ -81,18 +82,16 @@ def check_hostname(self, _: bool) -> None:
8182
self._check_hostname = False
8283

8384
@staticmethod
84-
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket:
85-
sock._kwargs = kwargs
86-
sock._secure_socket = True
87-
return sock
85+
def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket:
86+
return MocketSSLSocket._create(sock, *args, **kwargs)
8887

8988
@staticmethod
9089
def wrap_bio(
9190
incoming: Any, # _ssl.MemoryBIO
9291
outgoing: Any, # _ssl.MemoryBIO
9392
server_side: bool = False,
9493
server_hostname: str | bytes | None = None,
95-
) -> MocketSocket:
96-
ssl_obj = MocketSocket()
94+
) -> MocketSSLSocket:
95+
ssl_obj = MocketSSLSocket()
9796
ssl_obj._host = server_hostname
9897
return ssl_obj

mocket/ssl/socket.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from __future__ import annotations
2+
3+
import contextlib
4+
import ssl
5+
from datetime import datetime, timedelta
6+
from typing import Any
7+
8+
import urllib3.util.ssl_
9+
10+
from mocket.mocket import Mocket
11+
from mocket.socket import MocketSocket
12+
from mocket.types import _PeerCertRetDictType
13+
14+
true_ssl_wrap_socket = None
15+
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
16+
true_urllib3_wrap_socket = None
17+
18+
with contextlib.suppress(ImportError):
19+
# from Py3.12 it's only under SSLContext
20+
from ssl import wrap_socket as ssl_wrap_socket
21+
22+
true_ssl_wrap_socket = ssl_wrap_socket
23+
24+
with contextlib.suppress(ImportError):
25+
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
26+
27+
true_urllib3_wrap_socket = urllib3_wrap_socket
28+
29+
30+
class MocketSSLSocket(MocketSocket):
31+
def __init__(self, *args: Any, **kwargs: Any) -> None:
32+
super().__init__(*args, **kwargs)
33+
34+
self._did_handshake = False
35+
36+
def read(self, buffersize: int | None = None) -> bytes:
37+
rv = self.io.read(buffersize)
38+
if rv:
39+
self._sent_non_empty_bytes = True
40+
if self._did_handshake and not self._sent_non_empty_bytes:
41+
raise ssl.SSLWantReadError("The operation did not complete (read)")
42+
return rv
43+
44+
def write(self, data: bytes) -> int | None:
45+
return self.send(data)
46+
47+
def do_handshake(self) -> None:
48+
self._did_handshake = True
49+
50+
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
51+
if not (self._host and self._port):
52+
self._address = self._host, self._port = Mocket._address
53+
54+
now = datetime.now()
55+
shift = now + timedelta(days=30 * 12)
56+
return {
57+
"notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
58+
"subjectAltName": (
59+
("DNS", f"*.{self._host}"),
60+
("DNS", self._host),
61+
("DNS", "*"),
62+
),
63+
"subject": (
64+
(("organizationName", f"*.{self._host}"),),
65+
(("organizationalUnitName", "Domain Control Validated"),),
66+
(("commonName", f"*.{self._host}"),),
67+
),
68+
}
69+
70+
def ciper(self) -> tuple[str, str, str]:
71+
return ("ADH", "AES256", "SHA")
72+
73+
def compression(self) -> str | None:
74+
return ssl.OP_NO_COMPRESSION
75+
76+
def unwrap(self) -> MocketSocket:
77+
return self
78+
79+
@classmethod
80+
def _create(cls, sock: MocketSocket, *args, **kwargs) -> MocketSSLSocket:
81+
ssl_socket = MocketSSLSocket()
82+
83+
ssl_socket._kwargs = kwargs
84+
ssl_socket._true_socket = true_urllib3_ssl_wrap_socket(
85+
sock._true_socket,
86+
**kwargs,
87+
)
88+
89+
ssl_socket._timeout = sock._timeout
90+
91+
ssl_socket._host = sock._host
92+
ssl_socket._port = sock._port
93+
ssl_socket._address = sock._address
94+
95+
ssl_socket._io = sock._io
96+
ssl_socket._entry = sock._entry
97+
98+
return ssl_socket

tests/test_http.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,12 @@ def test_sockets(self):
359359
sock = socket.socket(address[0], address[1], address[2])
360360

361361
sock.connect(address[-1])
362-
sock.write(f"{method} {path} HTTP/1.0\r\n")
363-
sock.write(f"Host: {host}\r\n")
364-
sock.write("Content-Type: application/json\r\n")
365-
sock.write("Content-Length: %d\r\n" % len(data))
366-
sock.write("Connection: close\r\n\r\n")
367-
sock.write(data)
362+
sock.send(f"{method} {path} HTTP/1.0\r\n".encode())
363+
sock.send(f"Host: {host}\r\n".encode())
364+
sock.send(b"Content-Type: application/json\r\n")
365+
sock.send(b"Content-Length: %d\r\n" % len(data))
366+
sock.send(b"Connection: close\r\n\r\n")
367+
sock.send(data.encode())
368368
sock.close()
369369

370370
# Proof that worked.

0 commit comments

Comments
 (0)