Skip to content

Commit 571e8a2

Browse files
committed
refactor: type MocketSocket
1 parent 9f090ec commit 571e8a2

File tree

2 files changed

+72
-32
lines changed

2 files changed

+72
-32
lines changed

mocket/socket.py

Lines changed: 62 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,25 @@
1010
import ssl
1111
from datetime import datetime, timedelta
1212
from json.decoder import JSONDecodeError
13-
from typing import Any
13+
from types import TracebackType
14+
from typing import Any, Type
1415

1516
import urllib3.connection
1617
import urllib3.util.ssl_
18+
from typing_extensions import Self
1719

1820
from mocket.compat import decode_from_bytes, encode_to_bytes
21+
from mocket.entry import MocketEntry
1922
from mocket.io import MocketSocketCore
2023
from mocket.mocket import Mocket
2124
from mocket.mode import MocketMode
25+
from mocket.types import (
26+
Address,
27+
ReadableBuffer,
28+
WriteableBuffer,
29+
_PeerCertRetDictType,
30+
_RetAddress,
31+
)
2232
from mocket.utils import hexdump, hexload
2333

2434
true_create_connection = socket.create_connection
@@ -120,8 +130,13 @@ class MocketSocket:
120130
_io = None
121131

122132
def __init__(
123-
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
124-
):
133+
self,
134+
family: socket.AddressFamily | int = socket.AF_INET,
135+
type: socket.SocketKind | int = socket.SOCK_STREAM,
136+
proto: int = 0,
137+
fileno: int | None = None,
138+
**kwargs: Any,
139+
) -> None:
125140
self.true_socket = true_socket(family, type, proto)
126141
self._buflen = 65536
127142
self._entry = None
@@ -131,63 +146,69 @@ def __init__(
131146
self._truesocket_recording_dir = None
132147
self.kwargs = kwargs
133148

134-
def __str__(self):
149+
def __str__(self) -> str:
135150
return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"
136151

137-
def __enter__(self):
152+
def __enter__(self) -> Self:
138153
return self
139154

140-
def __exit__(self, exc_type, exc_val, exc_tb):
155+
def __exit__(
156+
self,
157+
type_: Type[BaseException] | None, # noqa: UP006
158+
value: BaseException | None,
159+
traceback: TracebackType | None,
160+
) -> None:
141161
self.close()
142162

143163
@property
144-
def io(self):
164+
def io(self) -> MocketSocketCore:
145165
if self._io is None:
146166
self._io = MocketSocketCore((self._host, self._port))
147167
return self._io
148168

149-
def fileno(self):
169+
def fileno(self) -> int:
150170
address = (self._host, self._port)
151171
r_fd, _ = Mocket.get_pair(address)
152172
if not r_fd:
153173
r_fd, w_fd = os.pipe()
154174
Mocket.set_pair(address, (r_fd, w_fd))
155175
return r_fd
156176

157-
def gettimeout(self):
177+
def gettimeout(self) -> float | None:
158178
return self.timeout
159179

160-
def setsockopt(self, family, type, proto):
180+
# FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None`
181+
def setsockopt(self, family: int, type: int, proto: int) -> None:
161182
self.family = family
162183
self.type = type
163184
self.proto = proto
164185

165186
if self.true_socket:
166187
self.true_socket.setsockopt(family, type, proto)
167188

168-
def settimeout(self, timeout):
189+
def settimeout(self, timeout: float | None) -> None:
169190
self.timeout = timeout
170191

171192
@staticmethod
172-
def getsockopt(level, optname, buflen=None):
193+
def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
173194
return socket.SOCK_STREAM
174195

175-
def do_handshake(self):
196+
def do_handshake(self) -> None:
176197
self._did_handshake = True
177198

178-
def getpeername(self):
199+
def getpeername(self) -> _RetAddress:
179200
return self._address
180201

181-
def setblocking(self, block):
202+
def setblocking(self, block: bool) -> None:
182203
self.settimeout(None) if block else self.settimeout(0.0)
183204

184-
def getblocking(self):
205+
def getblocking(self) -> bool:
185206
return self.gettimeout() is None
186207

187-
def getsockname(self):
208+
def getsockname(self) -> _RetAddress:
188209
return true_gethostbyname(self._address[0]), self._address[1]
189210

190-
def getpeercert(self, *args, **kwargs):
211+
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
191212
if not (self._host and self._port):
192213
self._address = self._host, self._port = Mocket._address
193214

@@ -207,22 +228,22 @@ def getpeercert(self, *args, **kwargs):
207228
),
208229
}
209230

210-
def unwrap(self):
231+
def unwrap(self) -> MocketSocket:
211232
return self
212233

213-
def write(self, data):
234+
def write(self, data: bytes) -> int | None:
214235
return self.send(encode_to_bytes(data))
215236

216-
def connect(self, address):
237+
def connect(self, address: Address) -> None:
217238
self._address = self._host, self._port = address
218239
Mocket._address = address
219240

220-
def makefile(self, mode="r", bufsize=-1):
241+
def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore:
221242
self._mode = mode
222243
self._bufsize = bufsize
223244
return self.io
224245

225-
def get_entry(self, data):
246+
def get_entry(self, data: bytes) -> MocketEntry | None:
226247
return Mocket.get_entry(self._host, self._port, data)
227248

228249
def sendall(self, data, entry=None, *args, **kwargs):
@@ -241,15 +262,20 @@ def sendall(self, data, entry=None, *args, **kwargs):
241262
self.io.truncate()
242263
self.io.seek(0)
243264

244-
def read(self, buffersize):
265+
def read(self, buffersize: int | None = None) -> bytes:
245266
rv = self.io.read(buffersize)
246267
if rv:
247268
self._sent_non_empty_bytes = True
248269
if self._did_handshake and not self._sent_non_empty_bytes:
249270
raise ssl.SSLWantReadError("The operation did not complete (read)")
250271
return rv
251272

252-
def recv_into(self, buffer, buffersize=None, flags=None):
273+
def recv_into(
274+
self,
275+
buffer: WriteableBuffer,
276+
buffersize: int | None = None,
277+
flags: int | None = None,
278+
) -> int:
253279
if hasattr(buffer, "write"):
254280
return buffer.write(self.read(buffersize))
255281
# buffer is a memoryview
@@ -258,7 +284,7 @@ def recv_into(self, buffer, buffersize=None, flags=None):
258284
buffer[: len(data)] = data
259285
return len(data)
260286

261-
def recv(self, buffersize, flags=None):
287+
def recv(self, buffersize: int, flags: int | None = None) -> bytes:
262288
r_fd, _ = Mocket.get_pair((self._host, self._port))
263289
if r_fd:
264290
return os.read(r_fd, buffersize)
@@ -271,7 +297,7 @@ def recv(self, buffersize, flags=None):
271297
exc.args = (0,)
272298
raise exc
273299

274-
def true_sendall(self, data, *args, **kwargs):
300+
def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
275301
if not MocketMode().is_allowed((self._host, self._port)):
276302
MocketMode.raise_not_allowed()
277303

@@ -359,7 +385,12 @@ def true_sendall(self, data, *args, **kwargs):
359385
# response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
360386
return encoded_response
361387

362-
def send(self, data, *args, **kwargs): # pragma: no cover
388+
def send(
389+
self,
390+
data: ReadableBuffer,
391+
*args: Any,
392+
**kwargs: Any,
393+
) -> int: # pragma: no cover
363394
entry = self.get_entry(data)
364395
if not entry or (entry and self._entry != entry):
365396
kwargs["entry"] = entry
@@ -371,15 +402,15 @@ def send(self, data, *args, **kwargs): # pragma: no cover
371402
self._entry = entry
372403
return len(data)
373404

374-
def close(self):
405+
def close(self) -> None:
375406
if self.true_socket and not self.true_socket._closed:
376407
self.true_socket.close()
377408
self._fd = None
378409

379-
def __getattr__(self, name):
410+
def __getattr__(self, name: str) -> Any:
380411
"""Do nothing catchall function, for methods like shutdown()"""
381412

382-
def do_nothing(*args, **kwargs):
413+
def do_nothing(*args: Any, **kwargs: Any) -> Any:
383414
pass
384415

385416
return do_nothing

mocket/types.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from os import PathLike
4-
from typing import Tuple, Union
4+
from typing import Any, Dict, Tuple, Union
55

66
from typing_extensions import Buffer, TypeAlias
77

@@ -11,3 +11,12 @@
1111
WriteableBuffer: TypeAlias = Buffer
1212
ReadableBuffer: TypeAlias = Buffer
1313
StrOrBytesPath: TypeAlias = Union[str, bytes, PathLike]
14+
15+
# from typeshed/stdlib/_socket.pyi
16+
_Address: TypeAlias = Union[Tuple[Any, ...], str, ReadableBuffer]
17+
_RetAddress: TypeAlias = Any
18+
19+
# from typeshed/stdlib/ssl.pyi
20+
_PCTRTT: TypeAlias = Tuple[Tuple[str, str], ...]
21+
_PCTRTTT: TypeAlias = Tuple[_PCTRTT, ...]
22+
_PeerCertRetDictType: TypeAlias = Dict[str, Union[str, _PCTRTTT, _PCTRTT]]

0 commit comments

Comments
 (0)