1010import ssl
1111from datetime import datetime , timedelta
1212from json .decoder import JSONDecodeError
13- from typing import Any
13+ from types import TracebackType
14+ from typing import Any , Type
1415
1516import urllib3 .connection
1617import urllib3 .util .ssl_
18+ from typing_extensions import Self
1719
1820from mocket .compat import decode_from_bytes , encode_to_bytes
21+ from mocket .entry import MocketEntry
1922from mocket .io import MocketSocketCore
2023from mocket .mocket import Mocket
2124from mocket .mode import MocketMode
25+ from mocket .types import (
26+ Address ,
27+ ReadableBuffer ,
28+ WriteableBuffer ,
29+ _PeerCertRetDictType ,
30+ _RetAddress ,
31+ )
2232from mocket .utils import hexdump , hexload
2333
2434true_create_connection = socket .create_connection
@@ -120,8 +130,13 @@ class MocketSocket:
120130 _io = None
121131
122132 def __init__ (
123- self , family = socket .AF_INET , type = socket .SOCK_STREAM , proto = 0 , ** kwargs
124- ):
133+ self ,
134+ family : socket .AddressFamily | int = socket .AF_INET ,
135+ type : socket .SocketKind | int = socket .SOCK_STREAM ,
136+ proto : int = 0 ,
137+ fileno : int | None = None ,
138+ ** kwargs : Any ,
139+ ) -> None :
125140 self .true_socket = true_socket (family , type , proto )
126141 self ._buflen = 65536
127142 self ._entry = None
@@ -131,63 +146,69 @@ def __init__(
131146 self ._truesocket_recording_dir = None
132147 self .kwargs = kwargs
133148
134- def __str__ (self ):
149+ def __str__ (self ) -> str :
135150 return f"({ self .__class__ .__name__ } )(family={ self .family } type={ self .type } protocol={ self .proto } )"
136151
137- def __enter__ (self ):
152+ def __enter__ (self ) -> Self :
138153 return self
139154
140- def __exit__ (self , exc_type , exc_val , exc_tb ):
155+ def __exit__ (
156+ self ,
157+ type_ : Type [BaseException ] | None , # noqa: UP006
158+ value : BaseException | None ,
159+ traceback : TracebackType | None ,
160+ ) -> None :
141161 self .close ()
142162
143163 @property
144- def io (self ):
164+ def io (self ) -> MocketSocketCore :
145165 if self ._io is None :
146166 self ._io = MocketSocketCore ((self ._host , self ._port ))
147167 return self ._io
148168
149- def fileno (self ):
169+ def fileno (self ) -> int :
150170 address = (self ._host , self ._port )
151171 r_fd , _ = Mocket .get_pair (address )
152172 if not r_fd :
153173 r_fd , w_fd = os .pipe ()
154174 Mocket .set_pair (address , (r_fd , w_fd ))
155175 return r_fd
156176
157- def gettimeout (self ):
177+ def gettimeout (self ) -> float | None :
158178 return self .timeout
159179
160- def setsockopt (self , family , type , proto ):
180+ # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None`
181+ def setsockopt (self , family : int , type : int , proto : int ) -> None :
161182 self .family = family
162183 self .type = type
163184 self .proto = proto
164185
165186 if self .true_socket :
166187 self .true_socket .setsockopt (family , type , proto )
167188
168- def settimeout (self , timeout ) :
189+ def settimeout (self , timeout : float | None ) -> None :
169190 self .timeout = timeout
170191
171192 @staticmethod
172- def getsockopt (level , optname , buflen = None ):
193+ def getsockopt (level : int , optname : int , buflen : int | None = None ) -> int :
173194 return socket .SOCK_STREAM
174195
175- def do_handshake (self ):
196+ def do_handshake (self ) -> None :
176197 self ._did_handshake = True
177198
178- def getpeername (self ):
199+ def getpeername (self ) -> _RetAddress :
179200 return self ._address
180201
181- def setblocking (self , block ) :
202+ def setblocking (self , block : bool ) -> None :
182203 self .settimeout (None ) if block else self .settimeout (0.0 )
183204
184- def getblocking (self ):
205+ def getblocking (self ) -> bool :
185206 return self .gettimeout () is None
186207
187- def getsockname (self ):
208+ def getsockname (self ) -> _RetAddress :
188209 return true_gethostbyname (self ._address [0 ]), self ._address [1 ]
189210
190- def getpeercert (self , * args , ** kwargs ) :
211+ def getpeercert (self , binary_form : bool = False ) -> _PeerCertRetDictType :
191212 if not (self ._host and self ._port ):
192213 self ._address = self ._host , self ._port = Mocket ._address
193214
@@ -207,22 +228,22 @@ def getpeercert(self, *args, **kwargs):
207228 ),
208229 }
209230
210- def unwrap (self ):
231+ def unwrap (self ) -> MocketSocket :
211232 return self
212233
213- def write (self , data ) :
234+ def write (self , data : bytes ) -> int | None :
214235 return self .send (encode_to_bytes (data ))
215236
216- def connect (self , address ) :
237+ def connect (self , address : Address ) -> None :
217238 self ._address = self ._host , self ._port = address
218239 Mocket ._address = address
219240
220- def makefile (self , mode = "r" , bufsize = - 1 ):
241+ def makefile (self , mode : str = "r" , bufsize : int = - 1 ) -> MocketSocketCore :
221242 self ._mode = mode
222243 self ._bufsize = bufsize
223244 return self .io
224245
225- def get_entry (self , data ) :
246+ def get_entry (self , data : bytes ) -> MocketEntry | None :
226247 return Mocket .get_entry (self ._host , self ._port , data )
227248
228249 def sendall (self , data , entry = None , * args , ** kwargs ):
@@ -241,15 +262,20 @@ def sendall(self, data, entry=None, *args, **kwargs):
241262 self .io .truncate ()
242263 self .io .seek (0 )
243264
244- def read (self , buffersize ) :
265+ def read (self , buffersize : int | None = None ) -> bytes :
245266 rv = self .io .read (buffersize )
246267 if rv :
247268 self ._sent_non_empty_bytes = True
248269 if self ._did_handshake and not self ._sent_non_empty_bytes :
249270 raise ssl .SSLWantReadError ("The operation did not complete (read)" )
250271 return rv
251272
252- def recv_into (self , buffer , buffersize = None , flags = None ):
273+ def recv_into (
274+ self ,
275+ buffer : WriteableBuffer ,
276+ buffersize : int | None = None ,
277+ flags : int | None = None ,
278+ ) -> int :
253279 if hasattr (buffer , "write" ):
254280 return buffer .write (self .read (buffersize ))
255281 # buffer is a memoryview
@@ -258,7 +284,7 @@ def recv_into(self, buffer, buffersize=None, flags=None):
258284 buffer [: len (data )] = data
259285 return len (data )
260286
261- def recv (self , buffersize , flags = None ):
287+ def recv (self , buffersize : int , flags : int | None = None ) -> bytes :
262288 r_fd , _ = Mocket .get_pair ((self ._host , self ._port ))
263289 if r_fd :
264290 return os .read (r_fd , buffersize )
@@ -271,7 +297,7 @@ def recv(self, buffersize, flags=None):
271297 exc .args = (0 ,)
272298 raise exc
273299
274- def true_sendall (self , data , * args , ** kwargs ) :
300+ def true_sendall (self , data : ReadableBuffer , * args : Any , ** kwargs : Any ) -> int :
275301 if not MocketMode ().is_allowed ((self ._host , self ._port )):
276302 MocketMode .raise_not_allowed ()
277303
@@ -359,7 +385,12 @@ def true_sendall(self, data, *args, **kwargs):
359385 # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
360386 return encoded_response
361387
362- def send (self , data , * args , ** kwargs ): # pragma: no cover
388+ def send (
389+ self ,
390+ data : ReadableBuffer ,
391+ * args : Any ,
392+ ** kwargs : Any ,
393+ ) -> int : # pragma: no cover
363394 entry = self .get_entry (data )
364395 if not entry or (entry and self ._entry != entry ):
365396 kwargs ["entry" ] = entry
@@ -371,15 +402,15 @@ def send(self, data, *args, **kwargs): # pragma: no cover
371402 self ._entry = entry
372403 return len (data )
373404
374- def close (self ):
405+ def close (self ) -> None :
375406 if self .true_socket and not self .true_socket ._closed :
376407 self .true_socket .close ()
377408 self ._fd = None
378409
379- def __getattr__ (self , name ) :
410+ def __getattr__ (self , name : str ) -> Any :
380411 """Do nothing catchall function, for methods like shutdown()"""
381412
382- def do_nothing (* args , ** kwargs ) :
413+ def do_nothing (* args : Any , ** kwargs : Any ) -> Any :
383414 pass
384415
385416 return do_nothing
0 commit comments