diff --git a/.coveragerc b/.coveragerc index d84a6fc8..c79b5374 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,10 @@ include = omit = kazoo/tests/* kazoo/testing/* + +# Note - this is a copy of the default exclusions from coverage 7.10.1 +[report] +exclude_lines = + #\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER) + ^\s*(((async )?def .*?)?\)(\s*->.*?)?:\s*)?\.\.\.\s*(#|$) + if (typing\.)?TYPE_CHECKING: diff --git a/.flake8 b/.flake8 index ba8a3d67..75277d34 100644 --- a/.flake8 +++ b/.flake8 @@ -1,14 +1,28 @@ [flake8] builtins = _ exclude = + docs/conf.py, .git, __pycache__, - .venv/,venv/, - .tox/, - build/,dist/,*egg, - docs/conf.py, - zookeeper/ -# See black's documentation for E203 + .venv/, + venv*/, + .tox*/, + build/, + dist/, + *egg, + zookeeper/, + max-line-length = 79 -extend-ignore = BLK100,E203 + +# See black's documentation for E203 +# +# I am not sure what version of flake8 hound is using but it gives a lot of +# undefined names for comments (F821) and redefinition of unused variables +# (F811). +# +# I've also had to supress F401 because you generally also need to specify +# type: ignore and something gets very upset when you have both in the same +# comment. +# +extend-ignore = BLK100,E203,F401,F811,F821 diff --git a/.gitignore b/.gitignore index 1c2b4b24..cc13a345 100644 --- a/.gitignore +++ b/.gitignore @@ -29,14 +29,15 @@ zookeeper/ .idea .project .pydevproject -.tox +.tox*/ venv*/ /.settings /.metadata +__pycache__/ !.gitignore !.git-blame-ignore-revs -.vscode/settings.json +.vscode/ .*_cache/ coverage.xml diff --git a/constraints.txt b/constraints.txt index 51c48a10..58be5adb 100644 --- a/constraints.txt +++ b/constraints.txt @@ -1,8 +1,14 @@ # Consistent testing environment. +# requirements.txt +eventlet>=0.17.1 ; implementation_name!='pypy' +gevent>=1.2 ; implementation_name!='pypy' + +# requirements-dev.txt black==22.10.0 coverage==6.3.2; python_version=="3.8" coverage==7.10.7; python_version > "3.8" flake8==5.0.2 +mypy==1.14.1 objgraph==3.5.0 pytest==6.2.5; python_version=="3.8" pytest==8.4.2; python_version > "3.8" diff --git a/docs/testing.rst b/docs/testing.rst index c98e35a8..b6d057be 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -37,10 +37,10 @@ Example: from kazoo.testing import KazooTestHarness class MyTest(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: self.setup_zookeeper() - def tearDown(self): + def tearDown(self)-> None: self.teardown_zookeeper() def testmycode(self): diff --git a/kazoo/client.py b/kazoo/client.py index 3029d1c5..893a6ca8 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -1,4 +1,7 @@ """Kazoo Zookeeper Client""" + +from __future__ import annotations + from collections import defaultdict, deque from functools import partial import inspect @@ -6,6 +9,22 @@ from os.path import split import re import warnings +from types import TracebackType +from typing import ( + cast, + overload, + Any, + Callable, + Deque, + Iterable, + Literal, + Optional, + Sequence, + Set, + TypedDict, + TYPE_CHECKING, +) +from typing_extensions import ParamSpec, Unpack, deprecated from kazoo.exceptions import ( AuthFailedError, @@ -63,6 +82,10 @@ from kazoo.recipe.queue import Queue, LockingQueue from kazoo.recipe.watchers import ChildrenWatch, DataWatch +if TYPE_CHECKING: + from kazoo.interfaces import Event, IAsyncResult, IHandler + from kazoo.protocol.states import ZnodeStat + CLOSED_STATES = ( KeeperState.EXPIRED_SESSION, @@ -89,6 +112,35 @@ ) +class LegacyRetryParams(TypedDict, total=False): + max_retries: int + retry_delay: float + retry_backoff: int + retry_max_delay: float + + +# Signature for functions called by add_listener +ListenerFunc = Callable[[KazooState], Optional[bool]] + +# Signatures for get, get_children and exists watches +WatchFunc = Callable[[WatchedEvent], Optional[bool]] + +GenericArgs = ParamSpec("GenericArgs") + + +# Kazoo retry parameters +class KazooRetryParams(TypedDict, total=False): + max_tries: int + delay: float + backoff: int + max_jitter: float + max_delay: float + ignore_expire: bool + sleep_func: Callable[[float], None] + deadline: float + interrupt: Callable[[], bool] + + class KazooClient(object): """An Apache Zookeeper Python client supporting alternate callback handlers and high-level functionality. @@ -100,29 +152,86 @@ class KazooClient(object): """ + @overload + def __init__( + self, + hosts: str | list[str] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: tuple[int | None, bytes] | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: Iterable[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, + read_only: bool | None = None, + randomize_hosts: bool = True, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + ) -> None: + ... + + # FIXME This should be deprecated then killed + @overload + @deprecated( + "Passing retry configuration parameters directly to the client" + " is deprecated, please pass a configured retry object (using param" + " connection_retry or command_retry)" + ) + def __init__( + self, + hosts: str | list[str] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: tuple[int | None, bytes] | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: Iterable[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, + read_only: bool | None = None, + randomize_hosts: bool = True, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + **kwargs: Unpack[LegacyRetryParams], + ) -> None: + ... + def __init__( self, - hosts="127.0.0.1:2181", - timeout=10.0, - client_id=None, - handler=None, - default_acl=None, - auth_data=None, - sasl_options=None, - read_only=None, - randomize_hosts=True, - connection_retry=None, - command_retry=None, - logger=None, - keyfile=None, - keyfile_password=None, - certfile=None, - ca=None, - use_ssl=False, - verify_certs=True, - check_hostname=False, - **kwargs, - ): + hosts: str | list[str] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: tuple[int | None, bytes] | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: Iterable[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, + read_only: bool | None = None, + randomize_hosts: bool = True, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + **kwargs: Unpack[LegacyRetryParams], + ) -> None: """Create a :class:`KazooClient` instance. All time arguments are in seconds. @@ -231,11 +340,18 @@ def __init__( "not the class: %s" % self.handler ) - self.auth_data = auth_data if auth_data else set([]) + self.auth_data = set(auth_data if auth_data else []) self.default_acl = default_acl self.randomize_hosts = randomize_hosts - self.hosts = None - self.chroot = None + # FIXME Note: hosts and chroot are set by set_hosts, which also checks + # for chroot changes at runtime, so we initialize them to None here to + # avoid confusion with the empty string that set_hosts would set them + # to. This is massively hacky as set_hosts is only called from here + # anyway, but I want to make this change minimally invasive. + # we should really do self.hosts, self.chroot = self.set_hosts(hosts) + # and have set_hosts return the hosts and chroot + self.hosts: list[tuple[str, int]] = None # type: ignore[assignment] + self.chroot: str = None # type: ignore[assignment] self.set_hosts(hosts) self.use_ssl = use_ssl @@ -247,11 +363,15 @@ def __init__( self.ca = ca # Curator like simplified state tracking, and listeners for # state transitions - self._state = KeeperState.CLOSED - self.state = KazooState.LOST - self.state_listeners = set() - self._child_watchers = defaultdict(set) - self._data_watchers = defaultdict(set) + self._state: KeeperState = KeeperState.CLOSED + self.state: KazooState = KazooState.LOST + self.state_listeners: set[ListenerFunc] = set() + self._child_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) + self._data_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) self._reset() self.read_only = read_only @@ -272,7 +392,14 @@ def __init__( self._stopped.set() self._writer_stopped.set() - self.retry = self._conn_retry = None + # FIXME This is kind of gross but we need to set these to something so + # that the type checker will understand that they are set by the time + # they are used and that they have the right type. + # We would do better to use a few variables/functions instead of + # overloading self.retry but this is a bit less invasive to the code + # and the type checker can understand it with a few hacks + self.retry: KazooRetry = None # type: ignore[assignment] + self._conn_retry: KazooRetry = None # type: ignore[assignment] if type(connection_retry) is dict: self._conn_retry = KazooRetry(**connection_retry) @@ -299,7 +426,11 @@ def __init__( ) if self.retry is None or self._conn_retry is None: - old_retry_keys = dict(_RETRY_COMPAT_DEFAULTS) + # Note: because of the hacks at line 280, mypy thinks this is + # unreachable + old_retry_keys = dict( # type: ignore[unreachable] + _RETRY_COMPAT_DEFAULTS + ) for key in old_retry_keys: try: old_retry_keys[key] = kwargs.pop(key) @@ -320,11 +451,13 @@ def __init__( if self._conn_retry is None: self._conn_retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) if self.retry is None: self.retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) # Managing legacy SASL options @@ -372,10 +505,21 @@ def __init__( # to avoid shared retry counts self._retry = self.retry - def _retry(*args, **kwargs): - return self._retry.copy()(*args, **kwargs) - - self.retry = _retry + def _retry( + func: Callable[GenericArgs, KazooRetry.RETRY_RETURN], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, + ) -> KazooRetry.RETRY_RETURN: + return self._retry.copy()(func, *args, **kwargs) + + # FIXME + # (expression has type "Callable[[VarArg(Any), KwArg(Any)], Any]", + # variable has type "KazooRetry") so basically self.retry needs to be + # set to that and then the type checker will understand that + # self.retry.copy() is a valid call. This is just a mess and needs the + # code rearranging to be more mypy friendly but this is the least + # invasive way to do it for now + self.retry = _retry # type: ignore[assignment] self.Barrier = partial(Barrier, self) self.Counter = partial(Counter, self) @@ -402,18 +546,18 @@ def _retry(*args, **kwargs): % (kwargs.keys(),) ) - def _reset(self): + def _reset(self) -> None: """Resets a variety of client states for a new connection.""" - self._queue = deque() - self._pending = deque() + self._queue: Deque[tuple[Any, IAsyncResult]] = deque() + self._pending: Deque[tuple[Any, IAsyncResult, int]] = deque() self._reset_watchers() self._reset_session() self.last_zxid = 0 - self._protocol_version = None + self._protocol_version: int | None = None - def _reset_watchers(self): - watchers = [] + def _reset_watchers(self) -> None: + watchers: list[WatchFunc] = [] for child_watchers in self._child_watchers.values(): watchers.extend(child_watchers) @@ -427,12 +571,12 @@ def _reset_watchers(self): for watch in watchers: self.handler.dispatch_callback(Callback("watch", watch, (ev,))) - def _reset_session(self): + def _reset_session(self) -> None: self._session_id = None self._session_passwd = b"\x00" * 16 @property - def client_state(self): + def client_state(self) -> KeeperState: """Returns the last Zookeeper client state This is the non-simplified state information and is generally @@ -442,7 +586,7 @@ def client_state(self): return self._state @property - def client_id(self): + def client_id(self) -> tuple[int | None, bytes] | None: """Returns the client id for this Zookeeper session if connected. @@ -455,12 +599,16 @@ def client_id(self): return None @property - def connected(self): + def connected(self) -> bool: """Returns whether the Zookeeper connection has been established.""" return self._live.is_set() - def set_hosts(self, hosts, randomize_hosts=None): + def set_hosts( + self, + hosts: str | list[str], + randomize_hosts: bool | None = None, + ) -> None: """sets the list of hosts used by this client. This function accepts the same format hosts parameter as the init @@ -504,7 +652,7 @@ def set_hosts(self, hosts, randomize_hosts=None): self.chroot = new_chroot - def add_listener(self, listener): + def add_listener(self, listener: ListenerFunc) -> None: """Add a function to be called for connection state changes. This function will be called with a @@ -519,15 +667,20 @@ def add_listener(self, listener): should be used so that the listener can return immediately. """ - if not (listener and callable(listener)): + # This check should be unnecessary but protects against people who are + # not using type checkers and accidentally passing in something that + # isn't callable. It should be removed. + if not ( + listener and callable(listener) # type: ignore[truthy-function] + ): raise ConfigurationError("listener must be callable") self.state_listeners.add(listener) - def remove_listener(self, listener): + def remove_listener(self, listener: ListenerFunc) -> None: """Remove a listener function""" self.state_listeners.discard(listener) - def _make_state_change(self, state): + def _make_state_change(self, state: KazooState) -> None: # skip if state is current if self.state == state: return @@ -544,7 +697,7 @@ def _make_state_change(self, state): except Exception: self.logger.exception("Error in connection state listener") - def _session_callback(self, state): + def _session_callback(self, state: KeeperState) -> None: if state == self._state: return @@ -581,9 +734,10 @@ def _session_callback(self, state): self._make_state_change(KazooState.SUSPENDED) self._reset_watchers() - def _notify_pending(self, state): + def _notify_pending(self, state: KeeperState) -> None: """Used to clear a pending response queue and request queue during connection drops.""" + exc: KazooException if state == KeeperState.AUTH_FAILED: exc = AuthFailedError() elif state == KeeperState.EXPIRED_SESSION: @@ -607,7 +761,7 @@ def _notify_pending(self, state): except IndexError: break - def _safe_close(self): + def _safe_close(self) -> None: self.handler.stop() timeout = self._session_timeout // 1000 if timeout < 10: @@ -618,7 +772,9 @@ def _safe_close(self): "and wouldn't close after %s seconds" % timeout ) - def _call(self, request, async_object): + def _call( + self, request: object, async_object: IAsyncResult + ) -> bool | None: """Ensure the client is in CONNECTED or SUSPENDED state and put the request in the queue if it is. @@ -647,14 +803,16 @@ def _call(self, request, async_object): async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None try: write_sock.send(b"\0") except: # NOQA async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None - def start(self, timeout=15): + def start(self, timeout: float = 15.0) -> None: """Initiate connection to ZK. :param timeout: Time in seconds to wait for connection to @@ -678,7 +836,7 @@ def start(self, timeout=15): "should be created before normal use." ) - def start_async(self): + def start_async(self) -> Event: """Asynchronously initiate connection to ZK. :returns: An event object that can be checked to see if the @@ -705,7 +863,7 @@ def start_async(self): self._connection.start() return self._live - def stop(self): + def stop(self) -> None: """Gracefully stop this Zookeeper session. This method can be called while a reconnection attempt is in @@ -721,18 +879,22 @@ def stop(self): return self._stopped.set() - self._queue.append((CloseInstance, None)) + self._queue.append((CloseInstance, cast("IAsyncResult", None))) try: - self._connection._write_sock.send(b"\0") + # This assert should never fail since the connection should + # have been started but I'm not sure how to persaude mypy of that + self._connection._write_sock.send( # type: ignore[union-attr] + b"\0" + ) finally: self._safe_close() - def restart(self): + def restart(self) -> None: """Stop and restart the Zookeeper session.""" self.stop() self.start() - def close(self): + def close(self) -> None: """Free any resources held by the client. This method should be called on a stopped client before it is @@ -742,7 +904,7 @@ def close(self): """ self._connection.close() - def command(self, cmd=b"ruok"): + def command(self, cmd: bytes = b"ruok") -> str: """Sent a management command to the current ZK server. Examples are `ruok`, `envi` or `stat`. @@ -761,8 +923,18 @@ def command(self, cmd=b"ruok"): if not self._live.is_set(): raise ConnectionLoss("No connection to server") - peer = self._connection._socket.getpeername()[:2] - peer_host = self._connection._socket.getpeername()[1] + # Need a way of persauding mypy that the connection is live and thus + # the socket is not None + peer = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + :2 + ] + ) + peer_host = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + 1 + ] + ) sock = self.handler.create_connection( peer, hostname=peer_host, @@ -780,7 +952,7 @@ def command(self, cmd=b"ruok"): sock.close() return result.decode("utf-8", "replace") - def server_version(self, retries=3): + def server_version(self, retries: int = 3) -> tuple[int, ...]: """Get the version of the currently connected ZK server. :returns: The server version, for example (3, 4, 3). @@ -790,7 +962,7 @@ def server_version(self, retries=3): """ - def _try_fetch(): + def _try_fetch() -> tuple[int, ...] | None: data = self.command(b"envi") data_parsed = {} for line in data.splitlines(): @@ -804,13 +976,19 @@ def _try_fetch(): if k: data_parsed[k] = v version = data_parsed.get(ENVI_VERSION_KEY, "") - version_digits = ENVI_VERSION.match(version).group(1) + # FIXME If you get an unexpected answer, you'll crash - not + # changing the code, so just ignoring the type error + version_digits = ENVI_VERSION.match( + version + ).group( # type: ignore[union-attr] + 1 + ) try: return tuple([int(d) for d in version_digits.split(".")]) except ValueError: return None - def _is_valid(version): + def _is_valid(version: tuple[int, ...] | None) -> bool: # All zookeeper versions should have at least major.minor # version numbers; if we get one that doesn't it is likely not # correct and was truncated... @@ -818,21 +996,29 @@ def _is_valid(version): return True return False + # FIXME A better way of doing this would be to put the initial + # _try_fetch in the loop and inline _is_valid but I want to minimise + # code changes + # Try 1 + retries amount of times to get a version that we know # will likely be acceptable... version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + # and the next 2 suppress should include return-value + # but hound is broken + return version # type: ignore for _i in range(0, retries): version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + return version # type: ignore raise KazooException( "Unable to fetch useable server" " version after trying %s times" % (1 + max(0, retries)) ) - def add_auth(self, scheme, credential): + def add_auth(self, scheme: str, credential: str) -> bool: """Send credentials to server. :param scheme: authentication scheme (default supported: @@ -847,9 +1033,9 @@ def add_auth(self, scheme, credential): the session state will be set to AUTH_FAILED as well. """ - return self.add_auth_async(scheme, credential).get() + return cast("bool", self.add_auth_async(scheme, credential).get()) - def add_auth_async(self, scheme, credential): + def add_auth_async(self, scheme: str, credential: str) -> IAsyncResult: """Asynchronously send credentials to server. Takes the same arguments as :meth:`add_auth`. @@ -868,7 +1054,7 @@ def add_auth_async(self, scheme, credential): self._call(Auth(0, scheme, credential), async_result) return async_result - def unchroot(self, path): + def unchroot(self, path: str) -> str: """Strip the chroot if applicable from the path.""" if not self.chroot: return path @@ -879,7 +1065,7 @@ def unchroot(self, path): else: return path - def sync_async(self, path): + def sync_async(self, path: str) -> IAsyncResult: """Asynchronous sync. :rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -888,10 +1074,10 @@ def sync_async(self, path): async_result = self.handler.async_result() @wrap(async_result) - def _sync_completion(result): + def _sync_completion(result: IAsyncResult) -> str: return self.unchroot(result.get()) - def _do_sync(): + def _do_sync() -> None: result = self.handler.async_result() self._call(Sync(_prefix_root(self.chroot, path)), result) result.rawlink(_sync_completion) @@ -899,7 +1085,7 @@ def _do_sync(): _do_sync() return async_result - def sync(self, path): + def sync(self, path: str) -> str: """Sync, blocks until response is acknowledged. Flushes channel between process and leader. @@ -913,18 +1099,44 @@ def sync(self, path): .. versionadded:: 0.5 """ - return self.sync_async(path).get() + return cast("str", self.sync_async(path).get()) + + @overload + def create( + self, + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[False] = False, + ) -> str: + ... + + @overload + def create( + self, + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[True] = True, + ) -> tuple[str, ZnodeStat]: + ... def create( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> str | tuple[str, ZnodeStat]: """Create a node with the given value as its data. Optionally set an ACL on the node. @@ -1003,26 +1215,29 @@ def create( The `include_data` option. """ acl = acl or self.default_acl - return self.create_async( - path, - value, - acl=acl, - ephemeral=ephemeral, - sequence=sequence, - makepath=makepath, - include_data=include_data, - ).get() + return cast( + "str | tuple[str, ZnodeStat]", + self.create_async( + path, + value, + acl=acl, + ephemeral=ephemeral, + sequence=sequence, + makepath=makepath, + include_data=include_data, + ).get(), + ) def create_async( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously create a ZNode. Takes the same arguments as :meth:`create`. @@ -1066,11 +1281,16 @@ def create_async( async_result = self.handler.async_result() @capture_exceptions(async_result) - def do_create(): + def do_create() -> None: result = self._create_async_inner( path, value, - acl, + # The way acl is constructed ends up confusing mypy, which + # thinks that acl can be None here, even though the code + # above ensures that if acl is None, it gets set to + # OPEN_ACL_UNSAFE, so we ignore the type error here. + # behaves differently in python3.8 and python3.14, sigh. + acl, # type: ignore[arg-type] flags, trailing=sequence, include_data=include_data, @@ -1078,12 +1298,14 @@ def do_create(): result.rawlink(create_completion) @capture_exceptions(async_result) - def retry_completion(result): + def retry_completion(result: IAsyncResult) -> None: result.get() do_create() @wrap(async_result) - def create_completion(result): + def create_completion( + result: IAsyncResult, + ) -> str | tuple[str, ZnodeStat] | None: try: if include_data: new_path, stat = result.get() @@ -1098,18 +1320,22 @@ def create_completion(result): else: parent, _ = split(path) self.ensure_path_async(parent, acl).rawlink(retry_completion) + return None do_create() return async_result def _create_async_inner( - self, path, value, acl, flags, trailing=False, include_data=False - ): + self, + path: str, + value: bytes | None, + acl: Sequence[ACL], + flags: int, + trailing: bool = False, + include_data: bool = False, + ) -> IAsyncResult: async_result = self.handler.async_result() - if include_data: - opcode = Create2 - else: - opcode = Create + opcode = Create2 if include_data else Create call_result = self._call( opcode( @@ -1126,19 +1352,24 @@ def _create_async_inner( # exception upwards to the do_create function in # KazooClient.create so that it gets set on the correct # async_result object - raise async_result.exception + # Note: Do we actually need call_result? It seems like we could + # just check the state of the exception, and avoid the typing + # stuff. + raise async_result.exception # type: ignore[misc] return async_result - def ensure_path(self, path, acl=None): + def ensure_path(self, path: str, acl: Sequence[ACL] | None = None) -> bool: """Recursively create a path if it doesn't exist. :param path: Path of node. :param acl: Permissions for node. """ - return self.ensure_path_async(path, acl).get() + return cast("bool", self.ensure_path_async(path, acl).get()) - def ensure_path_async(self, path, acl=None): + def ensure_path_async( + self, path: str, acl: Sequence[ACL] | None = None + ) -> IAsyncResult: """Recursively create a path asynchronously if it doesn't exist. Takes the same arguments as :meth:`ensure_path`. @@ -1151,19 +1382,21 @@ def ensure_path_async(self, path, acl=None): async_result = self.handler.async_result() @wrap(async_result) - def create_completion(result): + def create_completion(result: IAsyncResult) -> bool: try: - return result.get() + return cast("bool", result.get()) except NodeExistsError: return True @capture_exceptions(async_result) - def prepare_completion(next_path, result): + def prepare_completion(next_path: str, result: IAsyncResult) -> None: result.get() self.create_async(next_path, acl=acl).rawlink(create_completion) @wrap(async_result) - def exists_completion(path, result): + def exists_completion( + path: str, result: IAsyncResult + ) -> Literal[True] | None: if result.get(): return True parent, node = split(path) @@ -1173,12 +1406,15 @@ def exists_completion(path, result): ) else: self.create_async(path, acl=acl).rawlink(create_completion) + return None self.exists_async(path).rawlink(partial(exists_completion, path)) return async_result - def exists(self, path, watch=None): + def exists( + self, path: str, watch: WatchFunc | None = None + ) -> ZnodeStat | None: """Check if a node exists. If a watch is provided, it will be left on the node with the @@ -1198,9 +1434,13 @@ def exists(self, path, watch=None): returns a non-zero error code. """ - return self.exists_async(path, watch=watch).get() + return cast( + "ZnodeStat | None", self.exists_async(path, watch=watch).get() + ) - def exists_async(self, path, watch=None): + def exists_async( + self, path: str, watch: WatchFunc | None = None + ) -> IAsyncResult: """Asynchronously check if a node exists. Takes the same arguments as :meth:`exists`. @@ -1218,7 +1458,9 @@ def exists_async(self, path, watch=None): ) return async_result - def get(self, path, watch=None): + def get( + self, path: str, watch: WatchFunc | None = None + ) -> tuple[bytes, ZnodeStat]: """Get the value of a node. If a watch is provided, it will be left on the node with the @@ -1241,9 +1483,13 @@ def get(self, path, watch=None): returns a non-zero error code """ - return self.get_async(path, watch=watch).get() + return cast( + "tuple[bytes, ZnodeStat]", self.get_async(path, watch=watch).get() + ) - def get_async(self, path, watch=None): + def get_async( + self, path: str, watch: WatchFunc | None = None + ) -> IAsyncResult: """Asynchronously get the value of a node. Takes the same arguments as :meth:`get`. @@ -1261,7 +1507,30 @@ def get_async(self, path, watch=None): ) return async_result - def get_children(self, path, watch=None, include_data=False): + @overload + def get_children( + self, + path: str, + watch: WatchFunc | None = None, + include_data: Literal[False] = False, + ) -> list[str]: + ... + + @overload + def get_children( + self, + path: str, + watch: WatchFunc | None = None, + include_data: Literal[True] = True, + ) -> tuple[list[str], ZnodeStat]: + ... + + def get_children( + self, + path: str, + watch: WatchFunc | None = None, + include_data: bool = False, + ) -> list[str] | tuple[list[str], ZnodeStat]: """Get a list of child nodes of a path. If a watch is provided it will be left on the node with the @@ -1295,11 +1564,19 @@ def get_children(self, path, watch=None, include_data=False): The `include_data` option. """ - return self.get_children_async( - path, watch=watch, include_data=include_data - ).get() + return cast( + "list[str] | tuple[list[str], ZnodeStat]", + self.get_children_async( + path, watch=watch, include_data=include_data + ).get(), + ) - def get_children_async(self, path, watch=None, include_data=False): + def get_children_async( + self, + path: str, + watch: WatchFunc | None = None, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously get a list of child nodes of a path. Takes the same arguments as :meth:`get_children`. @@ -1314,6 +1591,8 @@ def get_children_async(self, path, watch=None, include_data=False): raise TypeError("Invalid type for 'include_data' (bool expected)") async_result = self.handler.async_result() + # FIXME? Do this as req = getc2 if include_data else getc + req: GetChildren | GetChildren2 if include_data: req = GetChildren2(_prefix_root(self.chroot, path), watch) else: @@ -1321,7 +1600,7 @@ def get_children_async(self, path, watch=None, include_data=False): self._call(req, async_result) return async_result - def get_acls(self, path): + def get_acls(self, path: str) -> tuple[list[ACL], ZnodeStat]: """Return the ACL and stat of the node of the given path. :param path: Path of the node. @@ -1339,9 +1618,11 @@ def get_acls(self, path): .. versionadded:: 0.5 """ - return self.get_acls_async(path).get() + return cast( + "tuple[list[ACL], ZnodeStat]", self.get_acls_async(path).get() + ) - def get_acls_async(self, path): + def get_acls_async(self, path: str) -> IAsyncResult: """Return the ACL and stat of the node of the given path. Takes the same arguments as :meth:`get_acls`. @@ -1355,7 +1636,9 @@ def get_acls_async(self, path): self._call(GetACL(_prefix_root(self.chroot, path)), async_result) return async_result - def set_acls(self, path, acls, version=-1): + def set_acls( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> ZnodeStat: """Set the ACL for the node of the given path. Set the ACL for the node of the given path if such a node @@ -1382,9 +1665,13 @@ def set_acls(self, path, acls, version=-1): .. versionadded:: 0.5 """ - return self.set_acls_async(path, acls, version).get() + return cast( + "ZnodeStat", self.set_acls_async(path, acls, version).get() + ) - def set_acls_async(self, path, acls, version=-1): + def set_acls_async( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> IAsyncResult: """Set the ACL for the node of the given path. Takes the same arguments as :meth:`set_acls`. @@ -1407,7 +1694,9 @@ def set_acls_async(self, path, acls, version=-1): ) return async_result - def set(self, path, value, version=-1): + def set( + self, path: str, value: bytes | None, version: int = -1 + ) -> ZnodeStat: """Set the value of a node. If the version of the node being updated is newer than the @@ -1440,9 +1729,11 @@ def set(self, path, value, version=-1): returns a non-zero error code. """ - return self.set_async(path, value, version).get() + return cast("ZnodeStat", self.set_async(path, value, version).get()) - def set_async(self, path, value, version=-1): + def set_async( + self, path: str, value: bytes | None, version: int = -1 + ) -> IAsyncResult: """Set the value of a node. Takes the same arguments as :meth:`set`. @@ -1463,7 +1754,7 @@ def set_async(self, path, value, version=-1): ) return async_result - def transaction(self): + def transaction(self) -> TransactionRequest: """Create and return a :class:`TransactionRequest` object Creates a :class:`TransactionRequest` object. A Transaction can @@ -1480,7 +1771,12 @@ def transaction(self): """ return TransactionRequest(self) - def delete(self, path, version=-1, recursive=False): + # This should not return anything. No return value is documented, and the + # two called functions return different things. AFAICT. But for now I + # want to minimise code changes + def delete( + self, path: str, version: int = -1, recursive: bool = False + ) -> Any: """Delete a node. The call will succeed if such a node exists, and the given @@ -1518,7 +1814,7 @@ def delete(self, path, version=-1, recursive=False): else: return self.delete_async(path, version).get() - def delete_async(self, path, version=-1): + def delete_async(self, path: str, version: int = -1) -> IAsyncResult: """Asynchronously delete a node. Takes the same arguments as :meth:`delete`, with the exception of `recursive`. @@ -1535,7 +1831,7 @@ def delete_async(self, path, version=-1): ) return async_result - def _delete_recursive(self, path): + def _delete_recursive(self, path: str) -> Literal[True] | None: try: children = self.get_children(path) except NoNodeError: @@ -1553,8 +1849,15 @@ def _delete_recursive(self, path): self.delete(path) except NoNodeError: # pragma: nocover pass + return None - def reconfig(self, joining, leaving, new_members, from_config=-1): + def reconfig( + self, + joining: str | None, + leaving: str | None, + new_members: str | None, + from_config: int = -1, + ) -> tuple[bytes, ZnodeStat]: """Reconfig a cluster. This call will succeed if the cluster was reconfigured accordingly. @@ -1625,9 +1928,15 @@ def reconfig(self, joining, leaving, new_members, from_config=-1): result = self.reconfig_async( joining, leaving, new_members, from_config ) - return result.get() + return cast("tuple[bytes, ZnodeStat]", result.get()) - def reconfig_async(self, joining, leaving, new_members, from_config): + def reconfig_async( + self, + joining: str | None, + leaving: str | None, + new_members: str | None, + from_config: int, + ) -> IAsyncResult: """Asynchronously reconfig a cluster. Takes the same arguments as :meth:`reconfig`. @@ -1674,14 +1983,19 @@ class TransactionRequest(object): """ - def __init__(self, client): + def __init__(self, client: KazooClient): self.client = client - self.operations = [] + self.operations: list[Any] = [] self.committed = False def create( - self, path, value=b"", acl=None, ephemeral=False, sequence=False - ): + self, + path: str, + value: bytes = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + ) -> None: """Add a create ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.create`, with the exception of `makepath`. @@ -1718,7 +2032,7 @@ def create( None, ) - def delete(self, path, version=-1): + def delete(self, path: str, version: int = -1) -> None: """Add a delete ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.delete`, with the exception of `recursive`. @@ -1730,7 +2044,7 @@ def delete(self, path, version=-1): raise TypeError("Invalid type for 'version' (int expected)") self._add(Delete(_prefix_root(self.client.chroot, path), version)) - def set_data(self, path, value, version=-1): + def set_data(self, path: str, value: bytes, version: int = -1) -> None: """Add a set ZNode value to the transaction. Takes the same arguments as :meth:`KazooClient.set`. @@ -1745,7 +2059,7 @@ def set_data(self, path, value, version=-1): SetData(_prefix_root(self.client.chroot, path), value, version) ) - def check(self, path, version): + def check(self, path: str, version: int) -> None: """Add a Check Version to the transaction. This command will fail and abort a transaction if the path @@ -1760,7 +2074,7 @@ def check(self, path, version): CheckVersion(_prefix_root(self.client.chroot, path), version) ) - def commit_async(self): + def commit_async(self) -> IAsyncResult: """Commit the transaction asynchronously. :rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -1772,28 +2086,38 @@ def commit_async(self): self.client._call(Transaction(self.operations), async_object) return async_object - def commit(self): + def commit(self) -> list[Any]: """Commit the transaction. :returns: A list of the results for each operation in the transaction. """ - return self.commit_async().get() + return cast("list[Any]", self.commit_async().get()) - def __enter__(self): + def __enter__(self) -> TransactionRequest: return self - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: """Commit and cleanup accumulated transaction data.""" if not exc_type: self.commit() + return None - def _check_tx_state(self): + def _check_tx_state(self) -> None: if self.committed: raise ValueError("Transaction already committed") - def _add(self, request, post_processor=None): + def _add( + self, + request: Any, + post_processor: Callable[[Any], Any] | None = None, + ) -> None: self._check_tx_state() self.client.logger.log(BLATHER, "Added %r to %r", request, self) self.operations.append(request) diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index b24c697c..ae9341df 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -1,6 +1,9 @@ """Kazoo Exceptions""" +from __future__ import annotations + from collections import defaultdict +from typing import Callable, Type class KazooException(Exception): @@ -51,17 +54,25 @@ class SASLException(KazooException): """ -def _invalid_error_code(): +def _invalid_error_code() -> Type[ZookeeperError]: raise RuntimeError("Invalid error code") -EXCEPTIONS = defaultdict(_invalid_error_code) +EXCEPTIONS: defaultdict[int, Type[ZookeeperError]] = defaultdict( + _invalid_error_code +) -def _zookeeper_exception(code): - def decorator(klass): +def _zookeeper_exception( + code: int, +) -> Callable[[Type[ZookeeperError]], Type[ZookeeperError]]: + def decorator(klass: Type[ZookeeperError]) -> Type[ZookeeperError]: EXCEPTIONS[code] = klass - klass.code = code + # Unfortunately there is currently no good of doing the assignment here + # in a way that type checkers would allow. It's a known problem (see + # https://discuss.python.org/t/how-to-type-hint-a-class-decorator/63010 + # ) + klass.code = code # type: ignore[attr-defined] return klass return decorator diff --git a/kazoo/handlers/eventlet.py b/kazoo/handlers/eventlet.py index 8869cc57..ee8ed47d 100644 --- a/kazoo/handlers/eventlet.py +++ b/kazoo/handlers/eventlet.py @@ -1,10 +1,14 @@ """A eventlet based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import contextlib import logging +from typing import cast, Any, Generator, TYPE_CHECKING + import eventlet from eventlet.green import socket as green_socket from eventlet.green import time as green_time @@ -15,6 +19,18 @@ from kazoo.handlers import utils from kazoo.handlers.utils import selector_select +if TYPE_CHECKING: + from kazoo.interfaces import ( + Event, + FdLike, + IHandler, + Lockable, + ReentrantLock, + Socket, + ) + from kazoo.protocol.states import Callback + + LOG = logging.getLogger(__name__) # sentinel objects @@ -22,17 +38,17 @@ @contextlib.contextmanager -def _yield_before_after(): +def _yield_before_after() -> Generator[None, None, None]: # Yield to any other co-routines... # # See: http://eventlet.net/doc/modules/greenthread.html # for how this zero sleep is really a cooperative yield to other potential # co-routines... - eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] try: yield finally: - eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] class TimeoutError(Exception): @@ -42,12 +58,15 @@ class TimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: IHandler): super(AsyncResult, self).__init__( - handler, green_threading.Condition, TimeoutError + handler, + green_threading.Condition, # type: ignore[attr-defined] + TimeoutError, ) +# FIXME This should inherit from IHandler class SequentialEventletHandler(object): """Eventlet handler for sequentially executing callbacks. @@ -81,26 +100,35 @@ class SequentialEventletHandler(object): queue_impl = green_queue.LightQueue queue_empty = green_queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialEventletHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() - self._workers = [] + self.callback_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self.completion_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self._workers: list[ # type: ignore[name-defined] + tuple[ + eventlet.GreenThread, + green_queue.LightQueue, + ] + ] = [] self._started = False @staticmethod - def sleep_func(wait): - green_time.sleep(wait) + def sleep_func(wait: float) -> None: + green_time.sleep(wait) # type: ignore[attr-defined, no-untyped-call] @property - def running(self): + def running(self) -> bool: return self._started timeout_exception = TimeoutError - def _process_completion_queue(self): + def _process_completion_queue(self) -> None: while True: - cb = self.completion_queue.get() + cb = self.completion_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -114,9 +142,9 @@ def _process_completion_queue(self): finally: del cb # release before possible idle - def _process_callback_queue(self): + def _process_callback_queue(self) -> None: while True: - cb = self.callback_queue.get() + cb = self.callback_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -130,58 +158,83 @@ def _process_callback_queue(self): finally: del cb # release before possible idle - def start(self): + def start(self) -> None: if not self._started: # Spawn our worker threads, we have # - A callback worker for watch events to be called # - A completion worker for completion events to be called - w = eventlet.spawn(self._process_completion_queue) + w = eventlet.spawn( + self._process_completion_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.completion_queue)) - w = eventlet.spawn(self._process_callback_queue) + w = eventlet.spawn( + self._process_callback_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.callback_queue)) self._started = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: while self._workers: w, q = self._workers.pop() - q.put(_STOP) + q.put(_STOP) # type: ignore[no-untyped-call] w.wait() self._started = False atexit.unregister(self.stop) - def socket(self, *args, **kwargs): + def socket(self) -> Socket: return utils.create_tcp_socket(green_socket) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(green_socket) - def event_object(self): - return green_threading.Event() + def event_object(self) -> Event: + return cast( + "Event", green_threading.Event() # type: ignore[attr-defined] + ) - def lock_object(self): - return green_threading.Lock() + def lock_object(self) -> Lockable: + return cast( + "Lockable", green_threading.Lock() # type: ignore[attr-defined] + ) - def rlock_object(self): - return green_threading.RLock() + def rlock_object(self) -> ReentrantLock: + return cast( + "ReentrantLock", + green_threading.RLock(), # type: ignore[attr-defined] + ) - def create_connection(self, *args, **kwargs): + # FIXME fix parameters + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(green_socket, *args, **kwargs) - def select(self, *args, **kwargs): + # FIXME fix parameters + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: with _yield_before_after(): + # Following appears to be a bug in mypy (see + # https://github.com/python/mypy/issues/6799) return selector_select( - *args, selectors_module=green_selectors, **kwargs + *args, + selectors_module=green_selectors, # type: ignore[misc] + **kwargs, ) - def async_result(self): - return AsyncResult(self) + def async_result(self) -> AsyncResult: + return AsyncResult(self) # type: ignore[arg-type] - def spawn(self, func, *args, **kwargs): - t = green_threading.Thread(target=func, args=args, kwargs=kwargs) + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> green_threading.Thread: # type: ignore[name-defined] + t = green_threading.Thread( # type: ignore[attr-defined] + target=func, args=args, kwargs=kwargs + ) t.daemon = True t.start() return t - def dispatch_callback(self, callback): - self.callback_queue.put(lambda: callback.func(*callback.args)) + def dispatch_callback(self, callback: Callback) -> None: + self.callback_queue.put( # type: ignore[no-untyped-call] + lambda: callback.func(*callback.args) + ) diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index f36389aa..7a3c215c 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -1,9 +1,13 @@ """A gevent based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import logging +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast + import gevent from gevent import socket import gevent.event @@ -14,18 +18,30 @@ from kazoo.handlers.utils import selector_select +from gevent import Greenlet from gevent.lock import Semaphore, RLock from kazoo.handlers import utils +if TYPE_CHECKING: + from kazoo.interfaces import FdLike, Lockable, Socket + from kazoo.protocol.states import Callback + _using_libevent = gevent.__version__.startswith("0.") log = logging.getLogger(__name__) + _STOP = object() AsyncResult = gevent.event.AsyncResult +# The following would be great typenames, but python3.8 complains about type +# objects not being indexable. +# GCallback = Callable[..., None] +# Worker = Greenlet[..., Any] +# CallbackQueue = gevent.queue.Queue[Callable[..., None]] + class SequentialGeventHandler(object): """Gevent handler for sequentially executing callbacks. @@ -53,24 +69,28 @@ class SequentialGeventHandler(object): queue_empty = gevent.queue.Empty sleep_func = staticmethod(gevent.sleep) - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialGeventHandler` instance""" - self.callback_queue = self.queue_impl() + self.callback_queue: gevent.queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._async = None self._state_change = Semaphore() - self._workers = [] + self._workers: list[Greenlet[..., Any]] = [] @property - def running(self): + def running(self) -> bool: return self._running class timeout_exception(gevent.Timeout): - def __init__(self, msg): + def __init__(self, msg: Any): gevent.Timeout.__init__(self, exception=msg) - def _create_greenlet_worker(self, queue): - def greenlet_worker(): + def _create_greenlet_worker( + self, queue: gevent.queue.Queue[Callable[..., None]] + ) -> Greenlet[..., Any]: + def greenlet_worker() -> None: while True: try: func = queue.get() @@ -88,7 +108,7 @@ def greenlet_worker(): return gevent.spawn(greenlet_worker) - def start(self): + def start(self) -> None: """Start the greenlet workers.""" with self._state_change: if self._running: @@ -98,12 +118,13 @@ def start(self): # Spawn our worker greenlets, we have # - A callback worker for watch events to be called + # FIXME Why the loop? for queue in (self.callback_queue,): w = self._create_greenlet_worker(queue) self._workers.append(w) atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the greenlet workers and empty all queues.""" with self._state_change: if not self._running: @@ -112,7 +133,7 @@ def stop(self): self._running = False for queue in (self.callback_queue,): - queue.put(_STOP) + queue.put(cast("Callable[..., None]", _STOP)) while self._workers: worker = self._workers.pop() @@ -123,33 +144,45 @@ def stop(self): atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: + # FIXME use the correct arguments, not *args, *kwargs return selector_select( - *args, selectors_module=gevent.selectors, **kwargs + # Likely a bug in mypy (see + # https://github.com/python/mypy/issues/6799) + *args, + selectors_module=gevent.selectors, + **kwargs, # type: ignore[misc] ) - def socket(self, *args, **kwargs): + def socket(self) -> Socket: + # See above return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + # See above return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> gevent.event.Event: """Create an appropriate Event object""" return gevent.event.Event() - def lock_object(self): + def lock_object(self) -> Lockable: """Create an appropriate Lock object""" - return gevent.thread.allocate_lock() + return cast( + "Lockable", + gevent.thread.allocate_lock(), # type: ignore[no-untyped-call] + ) - def rlock_object(self): + def rlock_object(self) -> RLock: """Create an appropriate RLock object""" return RLock() - def async_result(self): + def async_result(self) -> AsyncResult[Any]: """Create a :class:`AsyncResult` instance The :class:`AsyncResult` instance will have its completion @@ -160,11 +193,13 @@ def async_result(self): """ return AsyncResult() - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> gevent.Greenlet[..., Any]: """Spawn a function to run asynchronously""" return gevent.spawn(func, *args, **kwargs) - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index b9acd875..829a7010 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -10,6 +10,8 @@ :class:`~kazoo.handlers.gevent.SequentialGeventHandler` instead. """ + +from __future__ import annotations from __future__ import absolute_import import atexit @@ -19,9 +21,22 @@ import threading import time +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast + from kazoo.handlers import utils from kazoo.handlers.utils import selector_select - +from kazoo.interfaces import IHandler + +if TYPE_CHECKING: + from kazoo.interfaces import ( + Event, + FdLike, + Lockable, + ReentrantLock, + Socket, + SpawnedFunc, + ) + from kazoo.protocol.states import Callback # sentinel objects _STOP = object() @@ -29,7 +44,7 @@ log = logging.getLogger(__name__) -def _to_fileno(obj): +def _to_fileno(obj: FdLike) -> int: if isinstance(obj, int): fd = int(obj) elif hasattr(obj, "fileno"): @@ -55,13 +70,13 @@ class KazooTimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: Any) -> None: super(AsyncResult, self).__init__( handler, threading.Condition, KazooTimeoutError ) -class SequentialThreadingHandler(object): +class SequentialThreadingHandler(IHandler): """Threading handler for sequentially executing callbacks. This handler executes callbacks in a sequential manner. A queue is @@ -96,20 +111,26 @@ class SequentialThreadingHandler(object): queue_impl = queue.Queue queue_empty = queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialThreadingHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() + self.callback_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() + self.completion_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._state_change = threading.Lock() - self._workers = [] + self._workers: list[threading.Thread] = [] @property - def running(self): + def running(self) -> bool: return self._running - def _create_thread_worker(self, work_queue): - def _thread_worker(): # pragma: nocover + def _create_thread_worker( + self, work_queue: queue.Queue[Callable[..., None]] + ) -> threading.Thread: + def _thread_worker() -> None: # pragma: nocover while True: try: func = work_queue.get() @@ -128,7 +149,7 @@ def _thread_worker(): # pragma: nocover t = self.spawn(_thread_worker) return t - def start(self): + def start(self) -> None: """Start the worker threads.""" with self._state_change: if self._running: @@ -143,7 +164,7 @@ def start(self): self._running = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the worker threads and empty all queues.""" with self._state_change: if not self._running: @@ -152,7 +173,7 @@ def stop(self): self._running = False for work_queue in (self.completion_queue, self.callback_queue): - work_queue.put(_STOP) + work_queue.put(cast("Callable[..., None]", _STOP)) self._workers.reverse() while self._workers: @@ -164,41 +185,47 @@ def stop(self): self.completion_queue = self.queue_impl() atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: return selector_select(*args, **kwargs) - def socket(self): + def socket(self) -> Socket: return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> Event: """Create an appropriate Event object""" return threading.Event() - def lock_object(self): + def lock_object(self) -> Lockable: """Create a lock object""" - return threading.Lock() + # Note: This is not ideal, but the ContextManager Protocol seems to + # think you should return an object of the same type. + return cast("Lockable", threading.Lock()) - def rlock_object(self): + def rlock_object(self) -> ReentrantLock: """Create an appropriate RLock object""" - return threading.RLock() + return cast("ReentrantLock", threading.RLock()) - def async_result(self): + def async_result(self) -> AsyncResult: """Create a :class:`AsyncResult` instance""" return AsyncResult(self) - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> threading.Thread: t = threading.Thread(target=func, args=args, kwargs=kwargs) t.daemon = True t.start() return t - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 206806f6..019cc2d2 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -1,5 +1,7 @@ """Kazoo handler helpers""" +from __future__ import annotations + from collections import defaultdict import errno import functools @@ -8,6 +10,14 @@ import ssl import socket import time +from types import ModuleType +from typing import Any, Callable, Iterable, TypeVar, TYPE_CHECKING +from typing_extensions import ParamSpec + +from kazoo.interfaces import IAsyncResult, FdLike + +if TYPE_CHECKING: + from kazoo.interfaces import Socket HAS_FNCTL = True try: @@ -15,36 +25,51 @@ except ImportError: # pragma: nocover HAS_FNCTL = False + # sentinel objects +# Note: This needs to be a unique object that is not None, as None is used to +# indicate a successful result in AsyncResult. +# This should probably be an Enum, it would certainly be cleaner, but don't +# want to change the code too much. _NONE = object() +CallbackFunc = Callable[..., None] -class AsyncResult(object): + +class AsyncResult(IAsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler, condition_factory, timeout_factory): + def __init__( + self, + handler: Any, + condition_factory: Callable[[], Any], + timeout_factory: Callable[[], Any], + ) -> None: self._handler = handler - self._exception = _NONE + self._exception: object | Exception | None = _NONE self._condition = condition_factory() - self._callbacks = [] + self._callbacks: list[CallbackFunc] = [] self._timeout_factory = timeout_factory self.value = None - def ready(self): + def ready(self) -> bool: """Return true if and only if it holds a value or an exception""" return self._exception is not _NONE - def successful(self): + def successful(self) -> bool: """Return true if and only if it is ready and holds a value""" return self._exception is None @property - def exception(self): + def exception(self) -> Exception | None: if self._exception is not _NONE: - return self._exception + # The next line should have return-value, but hound ci + # is frankly nothing but a hound dog + return self._exception # type: ignore + return None - def set(self, value=None): + def set(self, value: Any = None) -> None: """Store the value. Wake up the waiters.""" with self._condition: self.value = value @@ -52,14 +77,14 @@ def set(self, value=None): self._do_callbacks() self._condition.notify_all() - def set_exception(self, exception): + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters.""" with self._condition: self._exception = exception self._do_callbacks() self._condition.notify_all() - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception. If there is no value raises TimeoutError. @@ -69,18 +94,18 @@ def get(self, block=True, timeout=None): if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] elif block: self._condition.wait(timeout) if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] # if we get to this point we timeout raise self._timeout_factory() - def get_nowait(self): + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. If nothing is available, raises TimeoutError @@ -88,14 +113,14 @@ def get_nowait(self): """ return self.get(block=False) - def wait(self, timeout=None): + def wait(self, timeout: float | None = None) -> bool: """Block until the instance is ready.""" with self._condition: if not self.ready(): self._condition.wait(timeout) return self._exception is not _NONE - def rawlink(self, callback): + def rawlink(self, callback: CallbackFunc) -> None: """Register a callback to call when a value or an exception is set""" with self._condition: @@ -106,7 +131,7 @@ def rawlink(self, callback): if self.ready(): self._do_callbacks() - def unlink(self, callback): + def unlink(self, callback: CallbackFunc) -> None: """Remove the callback set by :meth:`rawlink`""" with self._condition: if self.ready(): @@ -116,7 +141,7 @@ def unlink(self, callback): if callback in self._callbacks: self._callbacks.remove(callback) - def _do_callbacks(self): + def _do_callbacks(self) -> None: """Execute the callbacks that were registered by :meth:`rawlink`. If the handler is in running state this method only schedules the calls to be performed by the handler. If it's stopped, @@ -131,19 +156,21 @@ def _do_callbacks(self): functools.partial(callback, self)() -def _set_fd_cloexec(fd): +def _set_fd_cloexec(fd: Socket) -> None: flags = fcntl.fcntl(fd, fcntl.F_GETFD) fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) -def _set_default_tcpsock_options(module, sock): +def _set_default_tcpsock_options(module: ModuleType, sock: Socket) -> Socket: sock.setsockopt(module.IPPROTO_TCP, module.TCP_NODELAY, 1) if HAS_FNCTL: _set_fd_cloexec(sock) return sock -def create_socket_pair(module, port=0): +def create_socket_pair( + module: ModuleType, port: int = 0 +) -> tuple[Socket, Socket]: """Create socket pair. If socket.socketpair isn't available, we emulate it. @@ -182,32 +209,32 @@ def create_socket_pair(module, port=0): return client_sock, srv_sock -def create_tcp_socket(module): +def create_tcp_socket(module: ModuleType) -> Socket: """Create a TCP socket with the CLOEXEC flag set.""" type_ = module.SOCK_STREAM if hasattr(module, "SOCK_CLOEXEC"): # pragma: nocover # if available, set cloexec flag during socket creation type_ |= module.SOCK_CLOEXEC - sock = module.socket(module.AF_INET, type_) + sock: Socket = module.socket(module.AF_INET, type_) _set_default_tcpsock_options(module, sock) return sock def create_tcp_connection( - module, - address, - hostname=None, - timeout=None, - use_ssl=False, - ca=None, - certfile=None, - keyfile=None, - keyfile_password=None, - verify_certs=True, - check_hostname=False, - options=None, - ciphers=None, -): + module: ModuleType, + address: tuple[str, str | int], + hostname: str | None = None, + timeout: float | None = None, + use_ssl: bool = False, + ca: str | None = None, + certfile: str | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + verify_certs: bool = True, + check_hostname: bool = False, + options: ssl.Options | None = None, + ciphers: str | None = None, +) -> Socket: end = None if timeout is None: # thanks to create_connection() developers for @@ -215,7 +242,7 @@ def create_tcp_connection( timeout = module.getdefaulttimeout() if timeout is not None: end = time.monotonic() + timeout - sock = None + sock: Socket | None = None while True: timeout_at = end if end is None else end - time.monotonic() @@ -279,7 +306,13 @@ def create_tcp_connection( sock = module.create_connection(address, timeout_at) break except Exception as ex: - errnum = ex.errno if isinstance(ex, OSError) else ex[0] + # Seriously WTF? if ex is an exception, how can it be a tuple? + # I guess gevent can do this, but really... + errnum = ( + ex.errno + if isinstance(ex, OSError) + else ex[0] # type: ignore + ) if errnum == errno.EINTR: continue raise @@ -291,7 +324,16 @@ def create_tcp_connection( return sock -def capture_exceptions(async_result): +CapturedResult = TypeVar("CapturedResult") +GenericArgs = ParamSpec("GenericArgs") + + +def capture_exceptions( + async_result: IAsyncResult, +) -> Callable[ + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], +]: """Return a new decorated function that propagates the exceptions of the wrapped function to an async_result. @@ -299,20 +341,30 @@ def capture_exceptions(async_result): """ - def capture(function): + def capture( + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @functools.wraps(function) - def captured_function(*args, **kwargs): + def captured_function( + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs + ) -> CapturedResult | None: try: return function(*args, **kwargs) except Exception as exc: async_result.set_exception(exc) + return None return captured_function return capture -def wrap(async_result): +def wrap( + async_result: IAsyncResult, +) -> Callable[ + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], +]: """Return a new decorated function that propagates the return value or exception of wrapped function to an async_result. NOTE: Only propagates a non-None return value. @@ -321,9 +373,13 @@ def wrap(async_result): """ - def capture(function): + def capture( + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @capture_exceptions(async_result) - def captured_function(*args, **kwargs): + def captured_function( + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs + ) -> CapturedResult | None: value = function(*args, **kwargs) if value is not None: async_result.set(value) @@ -334,7 +390,7 @@ def captured_function(*args, **kwargs): return capture -def fileobj_to_fd(fileobj): +def fileobj_to_fd(fileobj: FdLike) -> int: """Return a file descriptor from a file object. Parameters: @@ -349,18 +405,25 @@ def fileobj_to_fd(fileobj): if isinstance(fileobj, int): fd = fileobj else: + # FIXME given the protocol I don't think the try/catch/int are + # required. try: fd = int(fileobj.fileno()) except (AttributeError, TypeError, ValueError): raise TypeError("Invalid file object: " "{!r}".format(fileobj)) + # FIXME Questionable, just let select deal with it. if fd < 0: raise TypeError("Invalid file descriptor: {}".format(fd)) return fd def selector_select( - rlist, wlist, xlist, timeout=None, selectors_module=selectors -): + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], + timeout: float | None = None, + selectors_module: ModuleType = selectors, +) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: """Selector-based drop-in replacement for select to overcome select limitation on a maximum filehandle value. """ @@ -374,8 +437,8 @@ def selector_select( selectors_module.EVENT_READ: rlist, selectors_module.EVENT_WRITE: wlist, } - fd_events = defaultdict(int) - fd_fileobjs = defaultdict(list) + fd_events: defaultdict[int, int] = defaultdict(int) + fd_fileobjs: defaultdict[int, list[FdLike]] = defaultdict(list) for event, fileobjs in events_mapping.items(): for fileobj in fileobjs: @@ -391,7 +454,9 @@ def selector_select( # gevent can raise OSError raise ValueError("Invalid event mask or fd") from e - revents, wevents, xevents = [], [], [] + revents: list[FdLike] = [] + wevents: list[FdLike] = [] + xevents: list[FdLike] = [] try: ready = selector.select(timeout) finally: diff --git a/kazoo/hosts.py b/kazoo/hosts.py index 3ece9318..cda746a3 100644 --- a/kazoo/hosts.py +++ b/kazoo/hosts.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import urllib.parse -def collect_hosts(hosts): +def collect_hosts( + hosts: str | list[str], +) -> tuple[list[tuple[str, int]], str | None]: """ Collect a set of hosts and an optional chroot from a string or a list of strings. @@ -12,8 +16,8 @@ def collect_hosts(hosts): else: host_ports, chroot = hosts, None else: - host_ports, chroot = hosts.partition("/")[::2] - host_ports = host_ports.split(",") + host_ports_1, chroot = hosts.partition("/")[::2] + host_ports = host_ports_1.split(",") chroot = "/" + chroot if chroot else None result = [] diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index 351f1fd8..18688ee5 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -8,10 +8,160 @@ """ +from __future__ import annotations + +import abc +import queue + +from types import TracebackType +from typing import ( + Any, + Callable, + Iterable, + Protocol, + Union, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from kazoo.protocol.states import Callback + # public API -class IHandler(object): +class HasFileNo(Protocol): + """Protocol for objects that support a fileno method.""" + + def fileno(self) -> int: + ... + + +FdLike = Union[int, HasFileNo] + + +class Socket(HasFileNo, Protocol): + """This is for things that provide a socket.socket-like interface. + + This is required because: + 1. The socket in gevent doesn't inherit from socket.socket + 2. mypy gets confused if you have a method called socket and + subsequently attempt to use socket or socket.socket as a return type + """ + + def close(self) -> None: + ... + + def fileno(self) -> int: + ... + + def getpeername(self) -> tuple[str, int]: + ... + + def getsockname(self) -> tuple[str, int]: + ... + + def recv(self, bufsize: int, flags: int = 0) -> bytes: + ... + + def send(self, data: bytes | memoryview, flags: int = 0) -> int: + ... + + def sendall(self, data: bytes, flags: int = 0) -> None: + ... + + def setblocking(self, flags: bool) -> None: + ... + + def setsockopt(self, level: int, optname: int, value: int) -> None: + ... + + def shutdown(self, flag: int) -> None: + ... + + +class Lockable(Protocol): + """This is what threading.Lock implements. + + In python 3.9+ it's available natively. Though given it has some + very odd typing, I wouldn't put money on it. + """ + + def __enter__(self) -> None: + ... + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> int | None: + """The gevent release returns an int...""" + ... + + def locked(self) -> bool: + ... + + +class ReentrantLock(Protocol): + """This is what threading.RLock implements. + + In python 3.14+, it's the same as Lock, which adds to the fun. + """ + + def __enter__(self) -> None: + ... + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> None: + ... + + +class Event(Protocol): + """Protocol for threading.Event""" + + def is_set(self) -> bool: + ... + + def set(self) -> None: + ... + + def clear(self) -> None: + ... + + def wait(self, timeout: float | None = None) -> bool: + ... + + +class Threadlike(Protocol): + """Protocol for something like a thread.""" + + def is_alive(self) -> bool: + ... + + def join(self, timeout: float | None = None) -> None: + ... + + +SpawnedFunc = Callable[..., None] + + +class IHandler(abc.ABC): """A Callback Handler for Zookeeper completion and watch callbacks. This object must implement several methods responsible for @@ -44,43 +194,69 @@ class IHandler(object): """ - def start(self): + timeout_exception: type[Exception] = None # type: ignore[assignment] + sleep_func: staticmethod[[float], None] = None # type: ignore[assignment] + queue_impl: type[queue.Queue[Any]] = None # type: ignore[assignment] + + @abc.abstractmethod + def start(self) -> None: """Start the handler, used for setting up the handler.""" - def stop(self): + @abc.abstractmethod + def stop(self) -> None: """Stop the handler. Should block until the handler is safely stopped.""" - def select(self): + @abc.abstractmethod + def select( + self, + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], + timeout: float | None = None, + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: """A select method that implements Python's select.select API""" - def socket(self): - """A socket method that implements Python's socket.socket + @abc.abstractmethod + def socket(self) -> Socket: + """A socket method that implements Python's socket.socket API""" + + # FIXME This should have a proper set of parameters. + @abc.abstractmethod + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + """A socket method that implements Python's socket.create_connection API""" - def create_connection(self): - """A socket method that implements Python's - socket.create_connection API""" + @abc.abstractmethod + def create_socket_pair(self) -> tuple[Socket, Socket]: + """A socket method that implements Python's socket.socketpair API""" - def event_object(self): + @abc.abstractmethod + def event_object(self) -> Event: """Return an appropriate object that implements Python's threading.Event API""" - def lock_object(self): + @abc.abstractmethod + def lock_object(self) -> Lockable: """Return an appropriate object that implements Python's threading.Lock API""" - def rlock_object(self): + @abc.abstractmethod + def rlock_object(self) -> ReentrantLock: """Return an appropriate object that implements Python's threading.RLock API""" - def async_result(self): + @abc.abstractmethod + def async_result(self) -> IAsyncResult: """Return an instance that conforms to the :class:`~IAsyncResult` interface appropriate for this handler""" - def spawn(self, func, *args, **kwargs): + @abc.abstractmethod + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> Threadlike: """Spawn a function to run asynchronously :param args: args to call the function with. @@ -91,7 +267,8 @@ def spawn(self, func, *args, **kwargs): """ - def dispatch_callback(self, callback): + @abc.abstractmethod + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object :param callback: A :class:`~kazoo.protocol.states.Callback` @@ -100,7 +277,7 @@ def dispatch_callback(self, callback): """ -class IAsyncResult(object): +class IAsyncResult(abc.ABC): """An Async Result object that can be queried for a value that has been set asynchronously. @@ -123,15 +300,18 @@ class IAsyncResult(object): """ - def ready(self): + @abc.abstractmethod + def ready(self) -> bool: """Return `True` if and only if it holds a value or an exception""" - def successful(self): + @abc.abstractmethod + def successful(self) -> bool: """Return `True` if and only if it is ready and holds a value""" - def set(self, value=None): + @abc.abstractmethod + def set(self, value: Any = None) -> None: """Store the value. Wake up the waiters. :param value: Value to store as the result. @@ -140,7 +320,8 @@ def set(self, value=None): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def set_exception(self, exception): + @abc.abstractmethod + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters. :param exception: Exception to raise when fetching the value. @@ -149,7 +330,8 @@ def set_exception(self, exception): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def get(self, block=True, timeout=None): + @abc.abstractmethod + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception :param block: Whether this method should block or return @@ -164,13 +346,15 @@ def get(self, block=True, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def get_nowait(self): + @abc.abstractmethod + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. If nothing is available, raise the Timeout exception class on the associated :class:`IHandler` interface.""" - def wait(self, timeout=None): + @abc.abstractmethod + def wait(self, timeout: float | None = None) -> Any: """Block until the instance is ready. :param timeout: How long to wait for a value when `block` is @@ -182,7 +366,8 @@ def wait(self, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def rawlink(self, callback): + @abc.abstractmethod + def rawlink(self, callback: Callable[[IAsyncResult], Any]) -> None: """Register a callback to call when a value or an exception is set @@ -194,10 +379,17 @@ def rawlink(self, callback): """ - def unlink(self, callback): + @abc.abstractmethod + def unlink(self, callback: Callable[[IAsyncResult], None]) -> None: """Remove the callback set by :meth:`rawlink` :param callback: A callback function to remove. :type callback: func """ + + @property + @abc.abstractmethod + def exception(self) -> Exception | None: + """The exception set by :meth:`set_exception` or `None` if no + exception has been set""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 3df7b162..e4ed1a17 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -1,4 +1,7 @@ """Zookeeper Protocol Connection Handler""" + +from __future__ import annotations + from binascii import hexlify from contextlib import contextmanager import copy @@ -8,6 +11,17 @@ import socket import ssl import time +from typing import ( + Callable, + ContextManager, + Iterator, + Literal, + TypeVar, + TYPE_CHECKING, + cast, + overload, +) +from typing_extensions import Buffer from kazoo.exceptions import ( AuthFailedError, @@ -41,9 +55,14 @@ ) from kazoo.retry import ( ForceRetryError, + KazooRetry, RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import Socket, Threadlike + try: import puresasl import puresasl.client @@ -76,7 +95,7 @@ # removed from Python3+ -def buffer(obj, offset=0): +def buffer(obj: Buffer, offset: int = 0) -> memoryview: return memoryview(obj)[offset:] @@ -94,22 +113,32 @@ class RWPinger(object): """ - def __init__(self, hosts, connection_func, socket_handling): + def __init__( + self, + hosts: list[tuple[str, int]], + connection_func: Callable[[tuple[str, int]], Socket], + socket_handling: Callable[[], ContextManager[None]], + ): self.hosts = hosts self.connection = connection_func - self.last_attempt = None + self.last_attempt: float | None = None self.socket_handling = socket_handling - def __iter__(self): + def __iter__(self) -> Iterator[tuple[str, int] | Literal[False] | None]: if not self.last_attempt: self.last_attempt = time.monotonic() delay = 0.5 while True: yield self._next_server(delay) - def _next_server(self, delay): + def _next_server( + self, delay: float + ) -> tuple[str, int] | Literal[False] | None: jitter = random.randint(0, 100) / 100.0 - while time.monotonic() < self.last_attempt + delay + jitter: + while ( + time.monotonic() + < self.last_attempt + delay + jitter # type: ignore[operator] + ): # Skip rw ping checks if its too soon return False for host, port in self.hosts: @@ -128,20 +157,36 @@ def _next_server(self, delay): except ConnectionDropped: return False + # NOTE: This does actually look like it's unreachable but I don't + # want to alter the code any more than necessary for the first + # pass. + # The loop is basically a sleep with jitter that can be # Add some jitter between host pings - while time.monotonic() < self.last_attempt + jitter: + while ( # type: ignore[unreachable] + time.monotonic() < self.last_attempt + jitter + ): return False - delay *= 2 + delay *= 2 # And while not unreachable, this is pointless + return None class RWServerAvailable(Exception): """Thrown if a RW Server becomes available""" +ReturnValue = TypeVar("ReturnValue") + + class ConnectionHandler(object): """Zookeeper connection handler""" - def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): + def __init__( + self, + client: KazooClient, + retry_sleeper: KazooRetry, + logger: logging.Logger | None = None, + sasl_options: dict[str, str] | None = None, + ): self.client = client self.handler = client.handler self.retry_sleeper = retry_sleeper @@ -154,15 +199,17 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): self.connection_stopped.set() self.ping_outstanding = client.handler.event_object() - self._read_sock = None - self._write_sock = None + self._read_sock: Socket | None = None + self._write_sock: Socket | None = None - self._socket = None - self._xid = None - self._rw_server = None - self._ro_mode = False + self._socket: Socket | None = None + self._xid: int | None = None + self._rw_server: tuple[str, int] | None = None + self._ro_mode: Iterator[ + Literal[False] | tuple[str, int] | None + ] | Literal[False] | None = False - self._connection_routine = None + self._connection_routine: Threadlike | None = None self.sasl_options = sasl_options self.sasl_cli = None @@ -170,14 +217,14 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): # This is instance specific to avoid odd thread bug issues in Python # during shutdown global cleanup @contextmanager - def _socket_error_handling(self): + def _socket_error_handling(self) -> Iterator[None]: try: yield except (socket.error, select.error) as e: err = getattr(e, "strerror", e) raise ConnectionDropped("socket connection error: %s" % (err,)) - def start(self): + def start(self) -> None: """Start the connection up""" if self.connection_closed.is_set(): rw_sockets = self.handler.create_socket_pair() @@ -189,7 +236,7 @@ def start(self): ) self._connection_routine = self.handler.spawn(self.zk_loop) - def stop(self, timeout=None): + def stop(self, timeout: float | None = None) -> bool: """Ensure the writer has stopped, wait to see if it does.""" self.connection_stopped.wait(timeout) if self._connection_routine: @@ -197,7 +244,7 @@ def stop(self, timeout=None): self._connection_routine = None return self.connection_stopped.is_set() - def close(self): + def close(self) -> None: """Release resources held by the connection The connection can be restarted afterwards. @@ -212,7 +259,7 @@ def close(self): if rs is not None: rs.close() - def _server_pinger(self): + def _server_pinger(self) -> RWPinger: """Returns a server pinger iterable, that will ping the next server in the list, and apply a back-off between attempts.""" return RWPinger( @@ -221,16 +268,21 @@ def _server_pinger(self): self._socket_error_handling, ) - def _read_header(self, timeout): + def _read_header( + self, timeout: float | None + ) -> tuple[ReplyHeader, bytes, int]: b = self._read(4, timeout) length = int_struct.unpack(b)[0] b = self._read(length, timeout) header, offset = ReplyHeader.deserialize(b, 0) return header, b, offset - def _read(self, length, timeout): + def _read(self, length: int, timeout: float | None) -> bytes: msgparts = [] remaining = length + # We know that self._socket is not None here because we only call + # this method when we have set up the connection and the read socket. + # But mypy doesn't understand that. with self._socket_error_handling(): while remaining > 0: # Because of SSL framing, a select may not return when using @@ -240,11 +292,13 @@ def _read(self, length, timeout): # data from the underlying socket. if ( hasattr(self._socket, "pending") - and self._socket.pending() > 0 + and self._socket.pending() > 0 # type: ignore[union-attr] ): pass else: - s = self.handler.select([self._socket], [], [], timeout)[0] + s = self.handler.select( + [cast("Socket", self._socket)], [], [], timeout + )[0] if not s: # pragma: nocover # If the read list is empty, we got a timeout. We don't # have to check wlist and xlist as we don't set any @@ -252,7 +306,9 @@ def _read(self, length, timeout): "socket time-out during read" ) try: - chunk = self._socket.recv(remaining) + chunk = self._socket.recv( # type: ignore[union-attr] + remaining + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -267,7 +323,24 @@ def _read(self, length, timeout): remaining -= len(chunk) return b"".join(msgparts) - def _invoke(self, timeout, request, xid=None): + @overload + def _invoke( + self, timeout: float | None, request: Connect + ) -> tuple[Connect, int | None]: + ... + + @overload + def _invoke( + self, timeout: float | None, request: Auth, xid: int + ) -> int | None: + ... + + def _invoke( + self, + timeout: float | None, + request: Auth | Connect, + xid: int | None = None, + ) -> tuple[Connect, int | None] | int | None: """A special writer used during connection establishment only""" self._submit(request, timeout, xid) @@ -296,7 +369,9 @@ def _invoke(self, timeout, request, xid=None): if hasattr(request, "deserialize"): try: - obj, _ = request.deserialize(msg, 0) + # This is a bit of an annoying ignore as I've just done a + # hasattr... + obj, _ = request.deserialize(msg, 0) # type:ignore[union-attr] except Exception: self.logger.exception( "Exception raised during deserialization " @@ -311,7 +386,12 @@ def _invoke(self, timeout, request, xid=None): return zxid - def _submit(self, request, timeout, xid=None): + def _submit( + self, + request: Auth | Connect | Ping | SASL, + timeout: float | None, + xid: int | None = None, + ) -> None: """Submit a request object with a timeout value and optional xid""" b = bytearray() @@ -328,13 +408,17 @@ def _submit(self, request, timeout, xid=None): ) self._write(int_struct.pack(len(b)) + b, timeout) - def _write(self, msg, timeout): + def _write(self, msg: bytes, timeout: float | None) -> None: """Write a raw msg to the socket""" sent = 0 msg_length = len(msg) + # Note: The casts/type: ignore are because mypy can't work out + # self._socket is not None, and I don't want to change any code. with self._socket_error_handling(): while sent < msg_length: - s = self.handler.select([], [self._socket], [], timeout)[1] + s = self.handler.select( + [], [cast("Socket", self._socket)], [], timeout + )[1] if not s: # pragma: nocover # If the write list is empty, we got a timeout. We don't # have to check rlist and xlist as we don't set any @@ -343,7 +427,9 @@ def _write(self, msg, timeout): ) msg_slice = buffer(msg, sent) try: - bytes_sent = self._socket.send(msg_slice) + bytes_sent = self._socket.send( # type:ignore[union-attr] + msg_slice + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -356,14 +442,14 @@ def _write(self, msg, timeout): raise ConnectionDropped("socket connection broken") sent += bytes_sent - def _read_watch_event(self, buffer, offset): + def _read_watch_event(self, buffer: bytes, offset: int) -> None: client = self.client watch, offset = Watch.deserialize(buffer, offset) path = watch.path self.logger.debug("Received EVENT: %s", watch) - watchers = [] + watchers: list[WatchFunc] = [] if watch.type in (CREATED_EVENT, CHANGED_EVENT): watchers.extend(client._data_watchers.pop(path, [])) @@ -385,10 +471,15 @@ def _read_watch_event(self, buffer, offset): return # Dump the watchers to the watch thread - for watch in watchers: - client.handler.dispatch_callback(Callback("watch", watch, (ev,))) - - def _read_response(self, header, buffer, offset): + for watch1 in watchers: + client.handler.dispatch_callback(Callback("watch", watch1, (ev,))) + + def _read_response( + self, + header: ReplyHeader, + buffer: bytes, + offset: int, + ) -> object | None: client = self.client request, async_object, xid = client._pending.popleft() if header.zxid and header.zxid > 0: @@ -404,7 +495,11 @@ def _read_response(self, header, buffer, offset): # Determine if its an exists request and a no node error exists_error = ( - header.err == NoNodeError.code and request.type == Exists.type + # NoNodeError does actually have a code. It's added by a decorator, + # which could possibly be better done via inheritance but this is + # less invasive to the existing code. + header.err == NoNodeError.code # type: ignore[attr-defined] + and request.type == Exists.type ) # Set the exception if its not an exists error @@ -430,7 +525,7 @@ def _read_response(self, header, buffer, offset): request, ) async_object.set_exception(exc) - return + return None self.logger.debug( "Received response(xid=%s): %r", xid, response ) @@ -452,8 +547,9 @@ def _read_response(self, header, buffer, offset): if isinstance(request, Close): self.logger.log(BLATHER, "Read close response") return CLOSE_RESPONSE + return None - def _read_socket(self, read_timeout): + def _read_socket(self, read_timeout: float) -> object | None: """Called when there's something to read on the socket""" client = self.client @@ -476,8 +572,13 @@ def _read_socket(self, read_timeout): self.logger.log(BLATHER, "Reading for header %r", header) return self._read_response(header, buffer, offset) + return None - def _send_request(self, read_timeout, connect_timeout): + def _send_request( + self, + read_timeout: float, + connect_timeout: float, + ) -> None: """Called when we have something to send out on the socket""" client = self.client try: @@ -489,7 +590,11 @@ def _send_request(self, read_timeout, connect_timeout): try: # Clear possible inconsistence (no request in the queue # but have data in the read socket), which causes cpu to spin. - self._read_sock.recv(1) + # + # We know _read_sock is not None because we only call this + # method when we have set up the connection and the read + # socket, but mypy doesn't understand that. + self._read_sock.recv(1) # type: ignore[union-attr] except OSError: pass return @@ -505,15 +610,19 @@ def _send_request(self, read_timeout, connect_timeout): if request.type == Auth.type: xid = AUTH_XID else: - self._xid = (self._xid % 2147483647) + 1 + # We must have initialised the xid counter by now + # Might want to consider initialising it to 0 instead of none? + self._xid = (self._xid % 2147483647) + 1 # type: ignore[operator] xid = self._xid self._submit(request, connect_timeout, xid) client._queue.popleft() - self._read_sock.recv(1) + # _read_sock should never be None here as we only call this method + # when we have set up the connection and the read socket. + self._read_sock.recv(1) # type: ignore[union-attr] client._pending.append((request, async_object, xid)) - def _send_ping(self, connect_timeout): + def _send_ping(self, connect_timeout: float) -> None: self.ping_outstanding.set() self._submit(PingInstance, connect_timeout, PING_XID) @@ -524,7 +633,7 @@ def _send_ping(self, connect_timeout): self._rw_server = result raise RWServerAvailable() - def zk_loop(self): + def zk_loop(self) -> None: """Main Zookeeper handling loop""" self.logger.log(BLATHER, "ZK loop started") @@ -546,16 +655,28 @@ def zk_loop(self): self.client._session_callback(KeeperState.CLOSED) self.logger.log(BLATHER, "Connection stopped") - def _expand_client_hosts(self): + def _expand_client_hosts(self) -> list[tuple[str, str, int]]: # Expand the entire list in advance so we can randomize it if needed - host_ports = [] + host_ports: list[tuple[str, str, int]] = [] for host, port in self.client.hosts: try: host = host.strip() for rhost in socket.getaddrinfo( host, port, 0, 0, socket.IPPROTO_TCP ): - host_ports.append((host, rhost[4][0], rhost[4][1])) + # FIXME These casts seem to be unnecessary on later + # versions of mypy/python + host_ports.append( + ( + host, + cast( # type: ignore[redundant-cast] + "str", rhost[4][0] + ), + cast( # type: ignore[redundant-cast] + "int", rhost[4][1] + ), + ) + ) except socket.gaierror as e: # Skip hosts that don't resolve self.logger.warning("Cannot resolve %s: %s", host, e) @@ -564,7 +685,7 @@ def _expand_client_hosts(self): random.shuffle(host_ports) return host_ports - def _connect_loop(self, retry): + def _connect_loop(self, retry: KazooRetry) -> object: # Iterate through the hosts a full cycle before starting over status = None host_ports = self._expand_client_hosts() @@ -586,7 +707,13 @@ def _connect_loop(self, retry): else: raise ForceRetryError("Reconnecting") - def _connect_attempt(self, host, hostip, port, retry): + def _connect_attempt( + self, + host: str, + hostip: str, + port: int, + retry: KazooRetry, + ) -> object: client = self.client KazooTimeoutError = self.handler.timeout_exception @@ -606,6 +733,9 @@ def _connect_attempt(self, host, hostip, port, retry): try: self._xid = 0 read_timeout, connect_timeout = self._connect(host, hostip, port) + # I think the above implies self._socket can't be none, and + # self._read_sock is set up in start but mypy can't tell that. + # Hence the casting. read_timeout = read_timeout / 1000.0 connect_timeout = connect_timeout / 1000.0 retry.reset() @@ -619,7 +749,14 @@ def _connect_attempt(self, host, hostip, port, retry): # Ensure our timeout is positive timeout = max([deadline - time.monotonic(), jitter_time]) s = self.handler.select( - [self._socket, self._read_sock], [], [], timeout + [ + # FIXME we should know these aren't None + cast("Socket", self._socket), + cast("Socket", self._read_sock), + ], + [], + [], + timeout, )[0] if not s: @@ -629,14 +766,14 @@ def _connect_attempt(self, host, hostip, port, retry): "outstanding heartbeat ping not received" ) else: - if self._socket in s: + if cast("Socket", self._socket) in s: response = self._read_socket(read_timeout) if response == CLOSE_RESPONSE: break # Check if any requests need sending before proceeding # to process more responses. Otherwise the responses # may choke out the requests. See PR#633. - if self._read_sock in s: + if cast("Socket", self._read_sock) in s: self._send_request(read_timeout, connect_timeout) # Requests act as implicit pings. last_send = time.monotonic() @@ -674,9 +811,18 @@ def _connect_attempt(self, host, hostip, port, retry): raise finally: if self._socket is not None: - self._socket.close() - - def _connect(self, host, hostip, port): + # I think this is a bug in mypy, as the socket does get set up + # in self._connect, but it doesn't seem to be able to track + # that. + self._socket.close() # type: ignore[unreachable] + return None + + def _connect( + self, + host: str, + hostip: str, + port: int, + ) -> tuple[float, float]: client = self.client self.logger.info( "Connecting to %s(%s):%s, use_ssl: %r", @@ -707,7 +853,7 @@ def _connect(self, host, hostip, port): check_hostname=self.client.check_hostname, ) - self._socket.setblocking(0) + self._socket.setblocking(0) # type: ignore[arg-type] connect = Connect( 0, @@ -771,23 +917,36 @@ def _connect(self, host, hostip, port): return read_timeout, connect_timeout - def _authenticate_with_sasl(self, host, timeout): + def _authenticate_with_sasl(self, host: str, timeout: float) -> None: """Establish a SASL authenticated connection to the server.""" if not PURESASL_AVAILABLE: raise SASLException("Missing SASL support") - if "service" not in self.sasl_options: - self.sasl_options["service"] = "zookeeper" + # Although this can only be called if sasl_options is not None, we + # really should just have make self.sasl_options into an empty dict + # in the constructor. However, I want to avoid code changes in as + # much as possible. + if "service" not in self.sasl_options: # type: ignore[operator] + self.sasl_options["service"] = "zookeeper" # type: ignore[index] # NOTE: Zookeeper hardcoded the domain for Digest authentication # instead of using the hostname. See # zookeeper/util/SecurityUtils.java#L74 and Server/Client # initializations. - if self.sasl_options["mechanism"] == "DIGEST-MD5": + if ( + self.sasl_options["mechanism"] # type: ignore[index] + == "DIGEST-MD5" + ): host = "zk-sasl-md5" - sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient( - host=host, **self.sasl_options + # I don't think the client.sasl_cli attribute is actually used + # anywhere else, so not sure why we need to set it on the client, + # but again, I want to avoid code changes as much as possible. + sasl_cli = ( + self.client.sasl_cli # type: ignore[attr-defined] + ) = puresasl.client.SASLClient( # type: ignore[no-untyped-call] + host=host, + **self.sasl_options, # type: ignore[arg-type] ) # Initialize the process with an empty challenge token diff --git a/kazoo/protocol/paths.py b/kazoo/protocol/paths.py index b8bf6650..7c47ce8a 100644 --- a/kazoo/protocol/paths.py +++ b/kazoo/protocol/paths.py @@ -1,4 +1,4 @@ -def normpath(path, trailing=False): +def normpath(path: str, trailing: bool = False) -> str: """Normalize path, eliminating double slashes, etc.""" comps = path.split("/") new_comps = [] @@ -16,7 +16,7 @@ def normpath(path, trailing=False): return new_path -def join(a, *p): +def join(a: str, *p: str) -> str: """Join two or more pathname components, inserting '/' as needed. If any component is an absolute path, all previous path components @@ -34,23 +34,23 @@ def join(a, *p): return path -def isabs(s): +def isabs(s: str) -> bool: """Test whether a path is absolute""" return s.startswith("/") -def basename(p): +def basename(p: str) -> str: """Returns the final component of a pathname""" i = p.rfind("/") + 1 return p[i:] -def _prefix_root(root, path, trailing=False): +def _prefix_root(root: str, path: str, trailing: bool = False) -> str: """Prepend a root to a path.""" return normpath( join(_norm_root(root), path.lstrip("/")), trailing=trailing ) -def _norm_root(root): +def _norm_root(root: str) -> str: return normpath(join("/", root)) diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 40e6360c..29a55f35 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -1,12 +1,23 @@ -"""Zookeeper Serializers, Deserializers, and NamedTuple objects""" -from collections import namedtuple +"""Zookeeper Serializers, Deserializers, and namedtuple objects + +Note: On python3.8, you can't do classvars with NamedTuple. + +FIXME As soon as we get off python3.8 we should change the namedtuple objects +to NamedTuple, as it should get better typechecking. +""" +from __future__ import annotations + import struct +from collections import namedtuple +from typing import ClassVar, Sequence, Union, TYPE_CHECKING -from kazoo.exceptions import EXCEPTIONS +from kazoo.exceptions import EXCEPTIONS, ZookeeperError from kazoo.protocol.states import ZnodeStat from kazoo.security import ACL from kazoo.security import Id +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc # Struct objects with formats compiled bool_struct = struct.Struct("B") @@ -21,20 +32,24 @@ stat_struct = struct.Struct("!qqqqiiiqiiq") -def read_string(buffer, offset): +def read_string(buffer: bytes, offset: int) -> tuple[str, int]: """Reads an int specified buffer into a string and returns the string and the new offset in the buffer""" length = int_struct.unpack_from(buffer, offset)[0] offset += int_struct.size if length < 0: - return None, offset + # A note: write_str sends a length of -1 to indicate a value of None + # was passed. Not entirely sure where this happens because none of the + # callers of read_string seem to expect a None value. + # Should be ignoring return-value but hound cli... + return None, offset # type: ignore else: index = offset offset += length return buffer[index : index + length].decode("utf-8"), offset -def read_acl(bytes, offset): +def read_acl(bytes: bytes, offset: int) -> tuple[ACL, int]: perms = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size scheme, offset = read_string(bytes, offset) @@ -42,7 +57,7 @@ def read_acl(bytes, offset): return ACL(perms, Id(scheme, id)), offset -def write_string(bytes): +def write_string(bytes: str | None) -> bytes: if not bytes: return int_struct.pack(-1) else: @@ -50,14 +65,14 @@ def write_string(bytes): return int_struct.pack(len(utf8_str)) + utf8_str -def write_buffer(bytes): +def write_buffer(bytes: bytes | None) -> bytes: if bytes is None: return int_struct.pack(-1) else: return int_struct.pack(len(bytes)) + bytes -def read_buffer(bytes, offset): +def read_buffer(bytes: bytes, offset: int) -> tuple[bytes | None, int]: length = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if length < 0: @@ -69,10 +84,10 @@ def read_buffer(bytes, offset): class Close(namedtuple("Close", "")): - type = -11 + type: ClassVar[int] = -11 @classmethod - def serialize(cls): + def serialize(self) -> bytes: return b"" @@ -80,10 +95,10 @@ def serialize(cls): class Ping(namedtuple("Ping", "")): - type = 11 + type: ClassVar[int] = 11 @classmethod - def serialize(cls): + def serialize(cls) -> bytes: return b"" @@ -97,9 +112,16 @@ class Connect( " time_out session_id passwd read_only", ) ): - type = None + protocol_version: int + last_zxid_seen: int + time_out: int + session_id: int + passwd: bytes + read_only: bool - def serialize(self): + type: int | None = None # Note: Not a classvar + + def serialize(self) -> bytearray: b = bytearray() b.extend( int_long_int_long_struct.pack( @@ -114,7 +136,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Connect, int]: proto_version, timeout, session_id = int_int_long_struct.unpack_from( bytes, offset ) @@ -133,9 +155,14 @@ def deserialize(cls, bytes, offset): class Create(namedtuple("Create", "path data acl flags")): - type = 1 + path: str + data: bytes | None + acl: Sequence[ACL] + flags: int + + type: ClassVar[int] = 1 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -150,59 +177,74 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> str: return read_string(bytes, offset)[0] class Delete(namedtuple("Delete", "path version")): - type = 2 + path: str + version: int - def serialize(self): + type: ClassVar[int] = 2 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b @classmethod - def deserialize(self, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> bool: return True class Exists(namedtuple("Exists", "path watcher")): - type = 3 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 3 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat | None: + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return stat if stat.czxid != -1 else None class GetData(namedtuple("GetData", "path watcher")): - type = 4 + path: str + watcher: WatchFunc | None - def serialize(self): + type: ClassVar[int] = 4 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class SetData(namedtuple("SetData", "path data version")): - type = 5 + path: str + data: bytes | None + version: int + + type: ClassVar[int] = 5 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -210,18 +252,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetACL(namedtuple("GetACL", "path")): - type = 6 + path: str - def serialize(self): + type: ClassVar[int] = 6 + + def serialize(self) -> bytearray: return bytearray(write_string(self.path)) @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[list[ACL], ZnodeStat] | list[ACL]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -231,14 +277,18 @@ def deserialize(cls, bytes, offset): for c in range(count): acl, offset = read_acl(bytes, offset) acls.append(acl) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return acls, stat class SetACL(namedtuple("SetACL", "path acls version")): - type = 7 + path: str + acls: Sequence[ACL] + version: int + + type: ClassVar[int] = 7 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(len(self.acls))) @@ -252,21 +302,24 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetChildren(namedtuple("GetChildren", "path watcher")): - type = 8 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 8 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -280,54 +333,71 @@ def deserialize(cls, bytes, offset): class Sync(namedtuple("Sync", "path")): - type = 9 + path: str - def serialize(self): + type: ClassVar[int] = 9 + + def serialize(self) -> bytes: return write_string(self.path) @classmethod - def deserialize(cls, buffer, offset): + def deserialize(cls, buffer: bytes, offset: int) -> str: return read_string(buffer, offset)[0] class GetChildren2(namedtuple("GetChildren2", "path watcher")): - type = 12 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 12 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[list[str], ZnodeStat] | list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover return [] - children = [] + children: list[str] = [] for c in range(count): child, offset = read_string(bytes, offset) children.append(child) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return children, stat class CheckVersion(namedtuple("CheckVersion", "path version")): - type = 13 + path: str + version: int + + type: ClassVar[int] = 13 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b +# FIXME Transaction class should move after Create2 +Transaction_Types = Union[Create, "Create2", Delete, SetData, CheckVersion] +Transaction_Response = Union[str, bool, ZnodeStat, ZookeeperError, None] + + class Transaction(namedtuple("Transaction", "operations")): - type = 14 + operations: list[Transaction_Types] + + type: ClassVar[int] = 14 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() for op in self.operations: b.extend( @@ -336,19 +406,19 @@ def serialize(self): return b + multiheader_struct.pack(-1, True, -1) @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> list[Transaction_Response]: header = MultiHeader(None, False, None) - results = [] - response = None + results: list[Transaction_Response] = [] + response: Transaction_Response = None while not header.done: if header.type == Create.type: response, offset = read_string(bytes, offset) elif header.type == Delete.type: response = True elif header.type == SetData.type: - response = ZnodeStat._make( - stat_struct.unpack_from(bytes, offset) - ) + response = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) offset += stat_struct.size elif header.type == CheckVersion.type: response = True @@ -362,8 +432,10 @@ def deserialize(cls, bytes, offset): return results @staticmethod - def unchroot(client, response): - resp = [] + def unchroot( + client: KazooClient, response: list[Transaction_Response] + ) -> list[Transaction_Response]: + resp: list[Transaction_Response] = [] for result in response: if isinstance(result, str): resp.append(client.unchroot(result)) @@ -373,9 +445,14 @@ def unchroot(client, response): class Create2(namedtuple("Create2", "path data acl flags")): - type = 15 + path: str + data: bytes | None + acl: Sequence[ACL] + flags: int - def serialize(self): + type: ClassVar[int] = 15 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -390,18 +467,23 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[str, ZnodeStat]: path, offset = read_string(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return path, stat class Reconfig( namedtuple("Reconfig", "joining leaving new_members config_id") ): - type = 16 + joining: str | None + leaving: str | None + new_members: str | None + config_id: int + + type: ClassVar[int] = 16 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.joining)) b.extend(write_string(self.leaving)) @@ -410,16 +492,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class Auth(namedtuple("Auth", "auth_type scheme auth")): - type = 100 + auth_type: int + scheme: str + auth: str - def serialize(self): + type: ClassVar[int] = 100 + + def serialize(self) -> bytes: return ( int_struct.pack(self.auth_type) + write_string(self.scheme) @@ -428,22 +516,30 @@ def serialize(self): class SASL(namedtuple("SASL", "challenge")): - type = 102 + challenge: bytes | None + + type: ClassVar[int] = 102 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_buffer(self.challenge)) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, int]: challenge, offset = read_buffer(bytes, offset) return challenge, offset class Watch(namedtuple("Watch", "type state path")): + type: int + state: int + path: str + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Watch, int]: """Given bytes and the current bytes offset, return the type, state, path, and new offset""" type, state = int_int_struct.unpack_from(bytes, offset) @@ -453,19 +549,27 @@ def deserialize(cls, bytes, offset): class ReplyHeader(namedtuple("ReplyHeader", "xid, zxid, err")): + xid: int + zxid: int + err: int + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[ReplyHeader, int]: """Given bytes and the current bytes offset, return a :class:`ReplyHeader` instance and the new offset""" new_offset = offset + reply_header_struct.size return ( - cls._make(reply_header_struct.unpack_from(bytes, offset)), + cls(*reply_header_struct.unpack_from(bytes, offset)), new_offset, ) -class MultiHeader(namedtuple("MultiHeader", "type done err")): - def serialize(self): +class MultiHeader(namedtuple("MultiHeader", "type, done, err")): + type: int | None + done: bool + err: int | None + + def serialize(self) -> bytearray: b = bytearray() b.extend(int_struct.pack(self.type)) b.extend([1 if self.done else 0]) @@ -473,7 +577,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[MultiHeader, int]: t, done, err = multiheader_struct.unpack_from(bytes, offset) offset += multiheader_struct.size return cls(t, done == 1, err), offset diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 480a586e..e2dc03fd 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -1,8 +1,13 @@ """Kazoo State and Event objects""" -from collections import namedtuple +from __future__ import annotations -class KazooState(object): +from enum import Enum +from typing import Any, Callable, Iterable, NamedTuple + + +# This is a (str, Enum) for backwards compatibility. +class KazooState(str, Enum): """High level connection state values States inspired by Netflix Curator. @@ -33,7 +38,8 @@ class KazooState(object): LOST = "LOST" -class KeeperState(object): +# This is a (str, Enum) for backwards compatibility. +class KeeperState(str, Enum): """Zookeeper State Represents the Zookeeper state. Watch functions will receive a @@ -70,7 +76,8 @@ class KeeperState(object): EXPIRED_SESSION = "EXPIRED_SESSION" -class EventType(object): +# This is a (str, Enum) for backwards compatibility. +class EventType(str, Enum): """Zookeeper Event Represents a Zookeeper event. Events trigger watch functions which @@ -117,7 +124,7 @@ class EventType(object): } -class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): +class WatchedEvent(NamedTuple): """A change on ZooKeeper that a Watcher is able to respond to. The :class:`WatchedEvent` includes exactly what happened, the @@ -140,8 +147,12 @@ class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): """ + type: EventType + state: KeeperState + path: str | None + -class Callback(namedtuple("Callback", ("type", "func", "args"))): +class Callback(NamedTuple): """A callback that is handed to a handler for dispatch :param type: Type of the callback, currently is only 'watch' @@ -150,15 +161,12 @@ class Callback(namedtuple("Callback", ("type", "func", "args"))): """ + type: str + func: Callable[..., Any] + args: Iterable[Any] -class ZnodeStat( - namedtuple( - "ZnodeStat", - "czxid mzxid ctime mtime version" - " cversion aversion ephemeralOwner dataLength" - " numChildren pzxid", - ) -): + +class ZnodeStat(NamedTuple): """A ZnodeStat structure with convenience properties When getting the value of a znode from Zookeeper, the properties for @@ -216,38 +224,50 @@ class ZnodeStat( """ + czxid: int + mzxid: int + ctime: int + mtime: int + version: int + cversion: int + aversion: int + ephemeralOwner: int + dataLength: int + numChildren: int + pzxid: int + @property - def acl_version(self): + def acl_version(self) -> int: return self.aversion @property - def children_version(self): + def children_version(self) -> int: return self.cversion @property - def created(self): + def created(self) -> float: return self.ctime / 1000.0 @property - def last_modified(self): + def last_modified(self) -> float: return self.mtime / 1000.0 @property - def owner_session_id(self): + def owner_session_id(self) -> int | None: return self.ephemeralOwner or None @property - def creation_transaction_id(self): + def creation_transaction_id(self) -> int: return self.czxid @property - def last_modified_transaction_id(self): + def last_modified_transaction_id(self) -> int: return self.mzxid @property - def data_length(self): + def data_length(self) -> int: return self.dataLength @property - def children_count(self): + def children_count(self) -> int: return self.numChildren diff --git a/kazoo/py.typed b/kazoo/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/kazoo/recipe/barrier.py b/kazoo/recipe/barrier.py index 683e807b..26ffc935 100644 --- a/kazoo/recipe/barrier.py +++ b/kazoo/recipe/barrier.py @@ -4,12 +4,19 @@ :Status: Unknown """ + +from __future__ import annotations + import os import socket import uuid +from typing import Literal, TYPE_CHECKING from kazoo.exceptions import KazooException, NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent + +if TYPE_CHECKING: + from kazoo.client import KazooClient class Barrier(object): @@ -27,7 +34,7 @@ class Barrier(object): """ - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """Create a Kazoo Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -37,11 +44,11 @@ def __init__(self, client, path): self.client = client self.path = path - def create(self): + def create(self) -> None: """Establish the barrier if it doesn't exist already""" self.client.retry(self.client.ensure_path, self.path) - def remove(self): + def remove(self) -> bool: """Remove the barrier :returns: Whether the barrier actually needed to be removed. @@ -54,7 +61,7 @@ def remove(self): except NoNodeError: return False - def wait(self, timeout=None): + def wait(self, timeout: float | None = None) -> bool: """Wait on the barrier to be cleared :returns: True if the barrier has been cleared, otherwise @@ -64,7 +71,7 @@ def wait(self, timeout=None): """ cleared = self.client.handler.event_object() - def wait_for_clear(event): + def wait_for_clear(event: WatchedEvent) -> None: if event.type == EventType.DELETED: cleared.set() @@ -93,7 +100,13 @@ class DoubleBarrier(object): """ - def __init__(self, client, path, num_clients, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + num_clients: int, + identifier: str | None = None, + ): """Create a Double Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -118,7 +131,7 @@ def __init__(self, client, path, num_clients, identifier=None): self.node_name = uuid.uuid4().hex self.create_path = self.path + "/" + self.node_name - def enter(self): + def enter(self) -> None: """Enter the barrier, blocks until all nodes have entered""" try: self.client.retry(self._inner_enter) @@ -128,7 +141,7 @@ def enter(self): self._best_effort_cleanup() self.participating = False - def _inner_enter(self): + def _inner_enter(self) -> Literal[True]: # make sure our barrier parent node exists if not self.assured_path: self.client.ensure_path(self.path) @@ -145,7 +158,7 @@ def _inner_enter(self): except NodeExistsError: pass - def created(event): + def created(event: WatchedEvent) -> None: if event.type == EventType.CREATED: ready.set() @@ -159,7 +172,7 @@ def created(event): self.client.ensure_path(self.path + "/ready") return True - def leave(self): + def leave(self) -> None: """Leave the barrier, blocks until all nodes have left""" try: self.client.retry(self._inner_leave) @@ -168,7 +181,7 @@ def leave(self): self._best_effort_cleanup() self.participating = False - def _inner_leave(self): + def _inner_leave(self) -> bool: # Delete the ready node if its around try: self.client.delete(self.path + "/ready") @@ -188,7 +201,7 @@ def _inner_leave(self): ready = self.client.handler.event_object() - def deleted(event): + def deleted(event: WatchedEvent) -> None: if event.type == EventType.DELETED: ready.set() @@ -214,7 +227,7 @@ def deleted(event): # Wait for the lowest to be deleted ready.wait() - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.retry(self.client.delete, self.create_path) except NoNodeError: diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index 0a22a6c7..cee96c82 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -10,20 +10,42 @@ See also: http://curator.apache.org/curator-recipes/tree-cache.html """ + +from __future__ import annotations from __future__ import absolute_import import contextlib import functools import logging import operator +from typing import ( + Any, + Callable, + Generator, + Protocol, + TypeVar, + Tuple, + TYPE_CHECKING, + Union, + overload, +) + from kazoo.exceptions import NoNodeError, KazooException from kazoo.protocol.paths import _prefix_root, join as kazoo_join -from kazoo.protocol.states import KazooState, EventType +from kazoo.protocol.states import KazooState, EventType, ZnodeStat + +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import IAsyncResult, Threadlike + from kazoo.protocol.states import WatchedEvent logger = logging.getLogger(__name__) +ReturnValue = TypeVar("ReturnValue") + + class TreeCache(object): """The cache of a ZooKeeper subtree. @@ -37,18 +59,18 @@ class TreeCache(object): _STOP = object() - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): self._client = client self._root = TreeNode.make_root(self, path) self._state = self.STATE_LATENT self._outstanding_ops = 0 self._is_initialized = False - self._error_listeners = [] - self._event_listeners = [] + self._error_listeners: list[Callable[[Exception], None]] = [] + self._event_listeners: list[Callable[[TreeEvent], None]] = [] self._task_queue = client.handler.queue_impl() - self._task_thread = None + self._task_thread: Threadlike | None = None - def start(self): + def start(self) -> None: """Starts the cache. The cache is not started automatically. You must call this method. @@ -85,7 +107,7 @@ def start(self): # without lock. self._in_background(self._root.on_created) - def close(self): + def close(self) -> None: """Closes the cache. A closed cache was detached from ZooKeeper's changes. And all nodes @@ -109,7 +131,9 @@ def close(self): # ZooKeeper actually. self._root.on_deleted() - def listen(self, listener): + def listen( + self, listener: Callable[[TreeEvent], None] + ) -> Callable[[TreeEvent], None]: """Registers a function to listen the cache events. The cache events are changes of local data. They are delivered from @@ -124,7 +148,9 @@ def listen(self, listener): self._event_listeners.append(listener) return listener - def listen_fault(self, listener): + def listen_fault( + self, listener: Callable[[Exception], None] + ) -> Callable[[Exception], None]: """Registers a function to listen the exceptions. It is possible to meet some exceptions during the cache running. You @@ -138,7 +164,9 @@ def listen_fault(self, listener): self._error_listeners.append(listener) return listener - def get_data(self, path, default=None): + def get_data( + self, path: str, default: NodeData | None = None + ) -> NodeData | None: """Gets data of a node from cache. :param path: The absolute path string. @@ -150,7 +178,9 @@ def get_data(self, path, default=None): node = self._find_node(path) return default if node is None else node._data - def get_children(self, path, default=None): + def get_children( + self, path: str, default: frozenset[str] | None = None + ) -> frozenset[str] | None: """Gets node children list from in-memory snapshot. :param path: The absolute path string. @@ -158,11 +188,14 @@ def get_children(self, path, default=None): does not exist. :raises ValueError: If the path is outside of this subtree. :returns: The :class:`frozenset` which including children names. + + # FIXME the default return value should be an empty frozenset, + # returning None is confusing. """ node = self._find_node(path) return default if node is None else frozenset(node._children) - def _find_node(self, path): + def _find_node(self, path: str) -> TreeNode | None: if not path.startswith(self._root._path): raise ValueError("outside of tree") striped_path = path[len(self._root._path) :].strip("/") @@ -170,25 +203,49 @@ def _find_node(self, path): current_node = self._root for node_name in splited_path: if node_name not in current_node._children: - return + return None current_node = current_node._children[node_name] return current_node - def _publish_event(self, event_type, event_data=None): + def _publish_event( + self, event_type: int, event_data: NodeData | None = None + ) -> None: event = TreeEvent.make(event_type, event_data) if self._state != self.STATE_CLOSED: logger.debug("public event: %r", event) self._in_background(self._do_publish_event, event) - def _do_publish_event(self, event): + def _do_publish_event(self, event: TreeEvent) -> None: for listener in self._event_listeners: with handle_exception(self._error_listeners): listener(event) - def _in_background(self, func, *args, **kwargs): + @overload + def _in_background( + self, func: Callable[[TreeEvent], None], event: TreeEvent + ) -> None: + ... + + @overload + def _in_background(self, func: Callable[[], None]) -> None: + ... + + @overload + def _in_background( + self, + func: Callable[[str, str, IAsyncResult], None], + method_name: str, + path: str, + result: IAsyncResult, + ) -> None: + ... + + def _in_background( # type: ignore[misc] + self, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: self._task_queue.put((func, args, kwargs)) - def _do_background(self): + def _do_background(self) -> None: while True: with handle_exception(self._error_listeners): cb = self._task_queue.get() @@ -200,7 +257,7 @@ def _do_background(self): # release before possible idle del cb, func, args, kwargs - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.SUSPENDED: self._publish_event(TreeEvent.CONNECTION_SUSPENDED) elif state == KazooState.CONNECTED: @@ -212,6 +269,11 @@ def _session_watcher(self, state): self._publish_event(TreeEvent.CONNECTION_LOST) +class AsyncWatcher(Protocol): + def __call__(self, path: str, watch: WatchFunc | None) -> IAsyncResult: + ... + + class TreeNode(object): """The tree node record. @@ -234,28 +296,28 @@ class TreeNode(object): STATE_LIVE = 1 STATE_DEAD = 2 - def __init__(self, tree, path, parent): + def __init__(self, tree: TreeCache, path: str, parent: TreeNode | None): self._tree = tree self._path = path self._parent = parent - self._depth = parent._depth + 1 if parent else 0 - self._children = {} + self._depth: int = parent._depth + 1 if parent is not None else 0 + self._children: dict[str, TreeNode] = {} self._state = self.STATE_PENDING - self._data = None + self._data: NodeData | None = None @classmethod - def make_root(cls, tree, path): + def make_root(cls, tree: TreeCache, path: str) -> TreeNode: return cls(tree, path, None) - def on_reconnected(self): + def on_reconnected(self) -> None: self._refresh() for child in self._children.values(): child.on_reconnected() - def on_created(self): + def on_created(self) -> None: self._refresh() - def on_deleted(self): + def on_deleted(self) -> None: old_children, self._children = self._children, {} old_data, self._data = self._data, None @@ -278,37 +340,43 @@ def on_deleted(self): del self._parent._children[child] self._reset_watchers() - def _publish_event(self, *args, **kwargs): - return self._tree._publish_event(*args, **kwargs) + def _publish_event( + self, event_type: int, event_data: NodeData | None = None + ) -> None: + return self._tree._publish_event(event_type, event_data) - def _reset_watchers(self): + def _reset_watchers(self) -> None: client = self._tree._client for _watchers in (client._data_watchers, client._child_watchers): _path = _prefix_root(client.chroot, self._path) _watcher = _watchers.get(_path, set()) _watcher.discard(self._process_watch) - def _refresh(self): + def _refresh(self) -> None: self._refresh_data() self._refresh_children() - def _refresh_data(self): + def _refresh_data(self) -> None: self._call_client("get", self._path) - def _refresh_children(self): + def _refresh_children(self) -> None: # TODO max-depth checking support self._call_client("get_children", self._path) - def _call_client(self, method_name, path): + def _call_client(self, method_name: str, path: str) -> None: assert method_name in ("get", "get_children", "exists") self._tree._outstanding_ops += 1 callback = functools.partial( self._tree._in_background, self._process_result, method_name, path ) - method = getattr(self._tree._client, method_name + "_async") + # The typing for this is really bad but the type checker can + # understand it with a few hacks + method: AsyncWatcher = getattr( + self._tree._client, method_name + "_async" + ) method(path, watch=self._process_watch).rawlink(callback) - def _process_watch(self, watched_event): + def _process_watch(self, watched_event: WatchedEvent) -> None: logger.debug("process_watch: %r", watched_event) with handle_exception(self._tree._error_listeners): if watched_event.type == EventType.CREATED: @@ -321,7 +389,9 @@ def _process_watch(self, watched_event): elif watched_event.type == EventType.CHILD: self._refresh_children() - def _process_result(self, method_name, path, result): + def _process_result( + self, method_name: str, path: str, result: IAsyncResult + ) -> None: logger.debug("process_result: %s %s", method_name, path) if method_name == "exists": assert self._parent is None, "unexpected EXISTS on non-root" @@ -332,7 +402,7 @@ def _process_result(self, method_name, path, result): self.on_created() elif method_name == "get_children": if result.successful(): - children = result.get() + children: list[str] = result.get() for child in sorted(children): full_path = kazoo_join(path, child) if child not in self._children: @@ -367,9 +437,12 @@ def _process_result(self, method_name, path, result): self._publish_event(TreeEvent.INITIALIZED) -class TreeEvent(tuple): +# FIXME these Tuple-based classes would look a lot better using NamedTuple +# though the event_type in TreeEvent needs sorting. +class TreeEvent(Tuple[int, Union["NodeData", None]]): """The immutable event tuple of cache.""" + # FIXME These should be an enum. NODE_ADDED = 0 NODE_UPDATED = 1 NODE_REMOVED = 2 @@ -385,7 +458,9 @@ class TreeEvent(tuple): event_data = property(operator.itemgetter(1)) @classmethod - def make(cls, event_type, event_data): + def make( + cls, event_type: int, event_data: NodeData | None = None + ) -> TreeEvent: """Creates a new TreeEvent tuple. :returns: A :class:`~kazoo.recipe.cache.TreeEvent` instance. @@ -402,7 +477,7 @@ def make(cls, event_type, event_data): return cls((event_type, event_data)) -class NodeData(tuple): +class NodeData(Tuple[str, bytes, ZnodeStat]): """The immutable node data tuple of cache.""" #: The absolute path string of current node. @@ -415,7 +490,7 @@ class NodeData(tuple): stat = property(operator.itemgetter(2)) @classmethod - def make(cls, path, data, stat): + def make(cls, path: str, data: bytes, stat: ZnodeStat) -> NodeData: """Creates a new NodeData tuple. :returns: A :class:`~kazoo.recipe.cache.NodeData` instance. @@ -424,7 +499,9 @@ def make(cls, path, data, stat): @contextlib.contextmanager -def handle_exception(listeners): +def handle_exception( + listeners: list[Callable[[Exception], None]], +) -> Generator[None, None, None]: try: yield except Exception as e: diff --git a/kazoo/recipe/counter.py b/kazoo/recipe/counter.py index 3b2cc339..53132252 100644 --- a/kazoo/recipe/counter.py +++ b/kazoo/recipe/counter.py @@ -4,9 +4,20 @@ :Status: Unknown """ + +from __future__ import annotations + +import struct +from typing import Union, TYPE_CHECKING + from kazoo.exceptions import BadVersionError from kazoo.retry import ForceRetryError -import struct + +if TYPE_CHECKING: + from kazoo.client import KazooClient + + +Number = Union[int, float] class Counter(object): @@ -58,7 +69,13 @@ class Counter(object): """ - def __init__(self, client, path, default=0, support_curator=False): + def __init__( + self, + client: KazooClient, + path: str, + default: Number = 0, + support_curator: bool = False, + ): """Create a Kazoo Counter :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -74,22 +91,24 @@ def __init__(self, client, path, default=0, support_curator=False): self.default_type = type(default) self.support_curator = support_curator self._ensured_path = False - self.pre_value = None - self.post_value = None + self.pre_value: Number | None = None + self.post_value: Number | None = None if self.support_curator and not isinstance(self.default, int): raise TypeError( "when support_curator is enabled the default " "type must be an int" ) - def _ensure_node(self): + def _ensure_node(self) -> None: if not self._ensured_path: # make sure our node exists self.client.ensure_path(self.path) self._ensured_path = True - def _value(self): + def _value(self) -> tuple[Number, int]: self._ensure_node() + # This is astonishingly hard to follow... + old: Union[bytes, str, Number] old, stat = self.client.get(self.path) if self.support_curator: old = struct.unpack(">i", old)[0] if old != b"" else self.default @@ -100,16 +119,16 @@ def _value(self): return data, version @property - def value(self): + def value(self) -> Number: return self._value()[0] - def _change(self, value): + def _change(self, value: Number) -> Counter: if not isinstance(value, self.default_type): raise TypeError("invalid type for value change") self.client.retry(self._inner_change, value) return self - def _inner_change(self, value): + def _inner_change(self, value: Number) -> None: self.pre_value, version = self._value() post_value = self.pre_value + value if self.support_curator: @@ -123,10 +142,10 @@ def _inner_change(self, value): raise ForceRetryError() self.post_value = post_value - def __add__(self, value): + def __add__(self, value: Number) -> Counter: """Add value to counter.""" return self._change(value) - def __sub__(self, value): + def __sub__(self, value: Number) -> Counter: """Subtract value from counter.""" return self._change(-value) diff --git a/kazoo/recipe/election.py b/kazoo/recipe/election.py index 93bb7258..1e28517b 100644 --- a/kazoo/recipe/election.py +++ b/kazoo/recipe/election.py @@ -4,8 +4,19 @@ :Status: Unknown """ + +from __future__ import annotations + +from typing import Callable, TYPE_CHECKING +from typing_extensions import ParamSpec + from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + +GenericArgs = ParamSpec("GenericArgs") + class Election(object): """Kazoo Basic Leader Election @@ -22,7 +33,12 @@ class Election(object): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): """Create a Kazoo Leader Election :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -34,7 +50,12 @@ def __init__(self, client, path, identifier=None): """ self.lock = client.Lock(path, identifier) - def run(self, func, *args, **kwargs): + def run( + self, + func: Callable[GenericArgs, None], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, + ) -> None: """Contend for the leadership This call will block until either this contender is cancelled @@ -57,7 +78,7 @@ def run(self, func, *args, **kwargs): except CancelledError: pass - def cancel(self): + def cancel(self) -> None: """Cancel participation in the election .. note:: @@ -68,7 +89,7 @@ def cancel(self): """ self.lock.cancel() - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders in the election diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index ce7fe567..1e4c9cf8 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -5,12 +5,25 @@ :Status: Beta """ + +from __future__ import annotations + import datetime import json import socket +from typing import Callable, TypedDict, TYPE_CHECKING, cast from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + + +class Lease(TypedDict): + version: int + holder: str + end: str + class NonBlockingLease(object): """Exclusive lease that does not block. @@ -48,11 +61,11 @@ class NonBlockingLease(object): def __init__( self, - client, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + path: str, + duration: datetime.timedelta, + identifier: str | None = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): """Create a non-blocking lease. @@ -71,7 +84,14 @@ def __init__( self.obtained = False self._attempt_obtaining(client, path, duration, ident, utcnow) - def _attempt_obtaining(self, client, path, duration, ident, utcnow): + def _attempt_obtaining( + self, + client: KazooClient, + path: str, + duration: datetime.timedelta, + ident: str, + utcnow: Callable[[], datetime.datetime], + ) -> None: client.ensure_path(path) holder_path = path + "/lease_holder" lock = client.Lock(path, ident) @@ -92,7 +112,7 @@ def _attempt_obtaining(self, client, path, duration, ident, utcnow): return client.delete(holder_path) end_lease = (now + duration).strftime(self._date_format) - new_data = { + new_data: Lease = { "version": self._version, "holder": ident, "end": end_lease, @@ -103,18 +123,13 @@ def _attempt_obtaining(self, client, path, duration, ident, utcnow): except CancelledError: pass - def _encode(self, data_dict): + def _encode(self, data_dict: Lease) -> bytes: return json.dumps(data_dict).encode(self._byte_encoding) - def _decode(self, raw): - return json.loads(raw.decode(self._byte_encoding)) - - # Python 2.x - def __nonzero__(self): - return self.obtained + def _decode(self, raw: bytes) -> Lease: + return cast("Lease", json.loads(raw.decode(self._byte_encoding))) - # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained @@ -140,12 +155,12 @@ class MultiNonBlockingLease(object): def __init__( self, - client, - count, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + count: int, + path: str, + duration: datetime.timedelta, + identifier: str | None = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): self.obtained = False for num in range(count): @@ -160,10 +175,5 @@ def __init__( self.obtained = True break - # Python 2.x - def __nonzero__(self): - return self.obtained - - # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained diff --git a/kazoo/recipe/lock.py b/kazoo/recipe/lock.py index 1f524702..6f3236db 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -14,9 +14,19 @@ and/or the lease has been lost. """ + +from __future__ import annotations + import re import time import uuid +from typing import ( + Iterable, + Literal, + Pattern, + TYPE_CHECKING, +) +from types import TracebackType from kazoo.exceptions import ( CancelledError, @@ -24,27 +34,37 @@ LockTimeout, NoNodeError, ) -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent from kazoo.retry import ( ForceRetryError, KazooRetry, RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient + class _Watch(object): - def __init__(self, duration=None): + def __init__(self, duration: float | None = None): self.duration = duration - self.started_at = None + self.started_at: float | None = None - def start(self): + def start(self) -> None: self.started_at = time.monotonic() - def leftover(self): + def leftover(self) -> float | None: if self.duration is None: return None else: - elapsed = time.monotonic() - self.started_at + # We should probably set started_at to either 0 or + # time.monotonic() in __init__ to avoid the type ignore + # here, but this is a private class and it's pretty clear + # that start() should be called before leftover() so I'm + # not sure it's worth it. + elapsed = ( + time.monotonic() - self.started_at # type: ignore[operator] + ) return max(0, self.duration - elapsed) @@ -77,7 +97,13 @@ class Lock(object): # sequence number. Involved in read/write locks. _EXCLUDE_NAMES = ["__lock__"] - def __init__(self, client, path, identifier=None, extra_lock_patterns=()): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + extra_lock_patterns: Iterable[str] = (), + ): """Create a Kazoo lock. :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -97,10 +123,10 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): """ self.client = client self.path = path - self._exclude_names = set( + self._exclude_names: set[str] = set( self._EXCLUDE_NAMES + list(extra_lock_patterns) ) - self._contenders_re = re.compile( + self._contenders_re: Pattern[str] = re.compile( r"(?:{patterns})(-?\d{{10}})$".format( patterns="|".join(self._exclude_names) ) @@ -109,7 +135,7 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): # some data is written to the node. this can be queried via # contenders() to see who is contending for the lock self.data = str(identifier or "").encode("utf-8") - self.node = None + self.node: str | None = None self.wake_event = client.handler.event_object() @@ -129,16 +155,21 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): ) self._acquire_method_lock = client.handler.lock_object() - def _ensure_path(self): + def _ensure_path(self) -> None: self.client.ensure_path(self.path) self.assured_path = True - def cancel(self): + def cancel(self) -> None: """Cancel a pending lock acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None, ephemeral=True): + def acquire( + self, + blocking: bool = True, + timeout: float | None = None, + ephemeral: bool = True, + ) -> bool: """ Acquire the lock. By defaults blocks and waits forever. @@ -204,11 +235,16 @@ def acquire(self, blocking=True, timeout=None, ephemeral=True): finally: self._acquire_method_lock.release() - def _watch_session(self, state): + def _watch_session(self, state: KazooState) -> bool: self.wake_event.set() return True - def _inner_acquire(self, blocking, timeout, ephemeral=True): + def _inner_acquire( + self, + blocking: bool, + timeout: float | None, + ephemeral: bool = True, + ) -> bool: # wait until it's our chance to get it.. if self.is_acquired: if not blocking: @@ -219,7 +255,7 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): if not self.assured_path: self._ensure_path() - node = None + node: str | None = None if self.create_tried: node = self._find_node() else: @@ -265,10 +301,10 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): finally: self.client.remove_listener(self._watch_session) - def _watch_predecessor(self, event): + def _watch_predecessor(self, event: WatchedEvent) -> None: self.wake_event.set() - def _get_predecessor(self, node): + def _get_predecessor(self, node: str) -> str | None: """returns `node`'s predecessor or None Note: This handle the case where the current lock is not a contender @@ -277,7 +313,7 @@ def _get_predecessor(self, node): """ node_sequence = node[len(self.prefix) :] children = self.client.get_children(self.path) - found_self = False + found_self: Literal[False] | re.Match[str] | None = False # Filter out the contenders using the computed regex contender_matches = [] for child in children: @@ -308,17 +344,17 @@ def _get_predecessor(self, node): sorted_matches = sorted(contender_matches, key=lambda m: m.groups()) return sorted_matches[-1].string - def _find_node(self): + def _find_node(self) -> str | None: children = self.client.get_children(self.path) for child in children: if child.startswith(self.prefix): return child return None - def _delete_node(self, node): + def _delete_node(self, node: str) -> None: self.client.delete(self.path + "/" + node) - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: node = self.node or self._find_node() if node: @@ -326,16 +362,18 @@ def _best_effort_cleanup(self): except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lock immediately.""" return self.client.retry(self._inner_release) - def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: - self._delete_node(self.node) + # I don't think it's possible for self.node to be None here if + # self.is_acquired is true. + self._delete_node(self.node) # type: ignore[arg-type] except NoNodeError: # pragma: nocover pass @@ -343,7 +381,7 @@ def _inner_release(self): self.node = None return True - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders for the lock. @@ -390,11 +428,17 @@ def contenders(self): return contenders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: self.release() + return None class WriteLock(Lock): @@ -492,7 +536,13 @@ class Semaphore(object): """ - def __init__(self, client, path, identifier=None, max_leases=1): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + max_leases: int = 1, + ): """Create a Kazoo Lock :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -528,7 +578,7 @@ def __init__(self, client, path, identifier=None, max_leases=1): self.cancelled = False self._session_expired = False - def _ensure_path(self): + def _ensure_path(self) -> None: result = self.client.ensure_path(self.path) self.assured_path = True if result is True: @@ -549,12 +599,16 @@ def _ensure_path(self): else: self.client.set(self.path, str(self.max_leases).encode("utf-8")) - def cancel(self): + def cancel(self) -> None: """Cancel a pending semaphore acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None): + def acquire( + self, + blocking: bool = True, + timeout: float | None = None, + ) -> bool: """Acquire the semaphore. By defaults blocks and waits forever. :param blocking: Block until semaphore is obtained or @@ -592,7 +646,11 @@ def acquire(self, blocking=True, timeout=None): return self.is_acquired - def _inner_acquire(self, blocking, timeout=None): + def _inner_acquire( + self, + blocking: bool, + timeout: float | None = None, + ) -> bool: """Inner loop that runs from the top anytime a command hits a retryable Zookeeper exception.""" self._session_expired = False @@ -607,7 +665,12 @@ def _inner_acquire(self, blocking, timeout=None): w = _Watch(duration=timeout) w.start() - lock = self.client.Lock(self.lock_path, self.data) + # FIXME This is passing bytes data, but self.client.Lock expects a str, + # which I think is a bug in this code. However, I don't want to + # change any code at this point, so we just ignore the type error here. + lock = self.client.Lock( + self.lock_path, self.data # type: ignore[arg-type] + ) try: gotten = lock.acquire(blocking=blocking, timeout=w.leftover()) if not gotten: @@ -633,10 +696,10 @@ def _inner_acquire(self, blocking, timeout=None): finally: lock.release() - def _watch_lease_change(self, event): + def _watch_lease_change(self, event: WatchedEvent) -> None: self.wake_event.set() - def _get_lease(self, data=None): + def _get_lease(self) -> bool: # Make sure the session is still valid if self._session_expired: raise ForceRetryError("Retry on session loss at top") @@ -665,25 +728,26 @@ def _get_lease(self, data=None): # Return current state return self.is_acquired - def _watch_session(self, state): + def _watch_session(self, state: KazooState) -> bool | None: if state == KazooState.LOST: self._session_expired = True self.wake_event.set() # Return true to de-register return True + return None - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.delete(self.create_path) except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lease immediately.""" return self.client.retry(self._inner_release) - def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: @@ -693,7 +757,7 @@ def _inner_release(self): self.is_acquired = False return True - def lease_holders(self): + def lease_holders(self) -> list[str]: """Return an unordered list of the current lease holders. .. note:: @@ -716,8 +780,13 @@ def lease_holders(self): pass return lease_holders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.release() diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index 21dc6ef4..bcc3b191 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -17,20 +17,38 @@ so that no two workers own the same queue. """ + +from __future__ import annotations + from functools import partial import logging import os import socket +from enum import Enum +from typing import ( + Callable, + Generic, + Iterator, + Iterable, + TYPE_CHECKING, + TypeVar, +) from kazoo.exceptions import KazooException, LockTimeout from kazoo.protocol.states import KazooState from kazoo.recipe.watchers import PatientChildrenWatch +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import Event, IAsyncResult + from kazoo.recipe.lock import Lock + from _typeshed import SupportsRichComparisonT log = logging.getLogger(__name__) -class PartitionState(object): +# This is a (str, Enum) for backwards compatibility. +class PartitionState(str, Enum): """High level partition state values .. attribute:: ALLOCATING @@ -63,7 +81,20 @@ class PartitionState(object): FAILURE = "FAILURE" -class SetPartitioner(object): +if TYPE_CHECKING: + # FIXME There should be a better way of doing this, but it's not clear how + # we can achieve what we want (specify that PartitionDataT is a type which + # is a SupportsRichComparisonT type) but clearly you can only do that with + # typechecking on. + PartitionDataT = TypeVar( + "PartitionDataT", + bound=SupportsRichComparisonT, # type:ignore[valid-type] + ) +else: + PartitionDataT = TypeVar("PartitionDataT") + + +class SetPartitioner(Generic[PartitionDataT]): """Partitions a set amongst members of a party This class will partition a set amongst members of a party such @@ -139,14 +170,18 @@ class SetPartitioner(object): def __init__( self, - client, - path, - set, - partition_func=None, - identifier=None, - time_boundary=30, - max_reaction_time=1, - state_change_event=None, + client: KazooClient, + path: str, + set: Iterable[PartitionDataT], + partition_func: Callable[ + [str, Iterable[str], Iterable[PartitionDataT]], + list[PartitionDataT], + ] + | None = None, + identifier: str | None = None, + time_boundary: float = 30, + max_reaction_time: float = 1, + state_change_event: Event | None = None, ): """Create a :class:`~SetPartitioner` instance @@ -176,13 +211,13 @@ def __init__( self._client = client self._path = path self._set = set - self._partition_set = [] + self._partition_set: list[PartitionDataT] = [] self._partition_func = partition_func or self._partitioner self._identifier = identifier or "%s-%s" % ( socket.getfqdn(), os.getpid(), ) - self._locks = [] + self._locks: list[Lock] = [] self._lock_path = "/".join([path, "locks"]) self._party_path = "/".join([path, "party"]) self._time_boundary = time_boundary @@ -208,33 +243,33 @@ def __init__( # so we know when we're ready self._child_watching(self._allocate_transition, client_handler=True) - def __iter__(self): + def __iter__(self) -> Iterator[PartitionDataT]: """Return the partitions in this partition set""" for partition in self._partition_set: yield partition @property - def failed(self): + def failed(self) -> bool: """Corresponds to the :attr:`PartitionState.FAILURE` state""" return self.state == PartitionState.FAILURE @property - def release(self): + def release(self) -> bool: """Corresponds to the :attr:`PartitionState.RELEASE` state""" return self.state == PartitionState.RELEASE @property - def allocating(self): + def allocating(self) -> bool: """Corresponds to the :attr:`PartitionState.ALLOCATING` state""" return self.state == PartitionState.ALLOCATING @property - def acquired(self): + def acquired(self) -> bool: """Corresponds to the :attr:`PartitionState.ACQUIRED` state""" return self.state == PartitionState.ACQUIRED - def wait_for_acquire(self, timeout=30): + def wait_for_acquire(self, timeout: float = 30) -> None: """Wait for the set to be partitioned and acquired :param timeout: How long to wait before returning. @@ -243,7 +278,7 @@ def wait_for_acquire(self, timeout=30): """ self._acquire_event.wait(timeout) - def release_set(self): + def release_set(self) -> None: """Call to release the set This method begins the step of allocating once the set has @@ -263,12 +298,12 @@ def release_set(self): self._set_state(PartitionState.ALLOCATING) self._child_watching(self._allocate_transition, client_handler=True) - def finish(self): + def finish(self) -> None: """Call to release the set and leave the party""" self._release_locks() self._fail_out() - def _fail_out(self): + def _fail_out(self) -> None: with self._state_change: self._set_state(PartitionState.FAILURE) if self._party.participating: @@ -277,7 +312,7 @@ def _fail_out(self): except KazooException: # pragma: nocover pass - def _allocate_transition(self, result): + def _allocate_transition(self, result: IAsyncResult) -> None: """Called when in allocating mode, and the children settled""" # Did we get an exception waiting for children to settle? @@ -288,7 +323,7 @@ def _allocate_transition(self, result): children, async_result = result.get() children_changed = self._client.handler.event_object() - def updated(result): + def updated(result: IAsyncResult) -> None: with self._state_change: children_changed.set() if self.acquired: @@ -307,7 +342,7 @@ def updated(result): # Check whether the state has changed during the lock acquisition # and abort the process if so. - def abort_if_needed(): + def abort_if_needed() -> bool: if self.state_id == state_id: if children_changed.is_set(): # The party has changed. Repartitioning... @@ -365,7 +400,7 @@ def abort_if_needed(): # This mustn't happen. Means a logical error. self._fail_out() - def _release_locks(self): + def _release_locks(self) -> None: """Attempt to completely remove all the locks""" self._acquire_event.clear() for lock in self._locks[:]: @@ -378,7 +413,7 @@ def _release_locks(self): else: self._locks.remove(lock) - def _abort_lock_acquisition(self): + def _abort_lock_acquisition(self) -> None: """Called during lock acquisition if a party change occurs""" self._release_locks() @@ -391,7 +426,13 @@ def _abort_lock_acquisition(self): self._child_watching(self._allocate_transition, client_handler=True) - def _child_watching(self, func=None, client_handler=False): + # FIXME This is only ever called with func=self._allocation_transition, but + # I didn't want to change the code. + def _child_watching( + self, + func: Callable[[IAsyncResult], None] | None = None, + client_handler: bool = False, + ) -> IAsyncResult: """Called when children are being watched to stabilize This actually returns immediately, child watcher spins up a @@ -410,11 +451,15 @@ def _child_watching(self, func=None, client_handler=False): # to ensure that the rawlink's it might use won't be # blocked if client_handler: - func = partial(self._client.handler.spawn, func) + # FIXME This feels wrong, but it may be because partial is + # confusing things. + func = partial( # type: ignore[assignment] + self._client.handler.spawn, func + ) asy.rawlink(func) return asy - def _establish_sessionwatch(self, state): + def _establish_sessionwatch(self, state: KazooState) -> bool: """Register ourself to listen for session events, we shut down if we become lost""" with self._state_change: @@ -427,7 +472,12 @@ def _establish_sessionwatch(self, state): return state == KazooState.LOST - def _partitioner(self, identifier, members, partitions): + def _partitioner( + self, + identifier: str, + members: Iterable[str], + partitions: Iterable[PartitionDataT], + ) -> list[PartitionDataT]: # Ensure consistent order of partitions/members all_partitions = sorted(partitions) workers = sorted(members) @@ -437,7 +487,7 @@ def _partitioner(self, identifier, members, partitions): # skipping the other workers return all_partitions[i :: len(workers)] - def _set_state(self, state): + def _set_state(self, state: PartitionState) -> None: self.state = state self.state_id += 1 self.state_change_event.set() diff --git a/kazoo/recipe/party.py b/kazoo/recipe/party.py index 2a0f5dfb..1fc1340b 100644 --- a/kazoo/recipe/party.py +++ b/kazoo/recipe/party.py @@ -7,15 +7,27 @@ used for determining members of a party. """ + +from __future__ import annotations + import uuid +from typing import Iterator, TYPE_CHECKING from kazoo.exceptions import NodeExistsError, NoNodeError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseParty(object): """Base implementation of a party.""" - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The party path to use. @@ -29,44 +41,52 @@ def __init__(self, client, path, identifier=None): self.ensured_path = False self.participating = False - def _ensure_parent(self): + def _ensure_parent(self) -> None: if not self.ensured_path: # make sure our parent node exists self.client.ensure_path(self.path) self.ensured_path = True - def join(self): + def join(self) -> None: """Join the party""" return self.client.retry(self._inner_join) - def _inner_join(self): + def _inner_join(self) -> None: self._ensure_parent() try: - self.client.create(self.create_path, self.data, ephemeral=True) + # This and the #type: ignore[attr-defined] below could be removed + # by setting up create_path in the constructor but trying to avoid + # changing the code too much. It does actually cause later versions + # of pylint to error though. + self.client.create( + self.create_path, # type: ignore[attr-defined] + self.data, + ephemeral=True, + ) self.participating = True except NodeExistsError: # node was already created, perhaps we are recovering from a # suspended connection self.participating = True - def leave(self): + def leave(self) -> bool: """Leave the party""" self.participating = False return self.client.retry(self._inner_leave) - def _inner_leave(self): + def _inner_leave(self) -> bool: try: - self.client.delete(self.create_path) + self.client.delete(self.create_path) # type: ignore[attr-defined] except NoNodeError: return False return True - def __len__(self): + def __len__(self) -> int: """Return a count of participating clients""" self._ensure_parent() return len(self._get_children()) - def _get_children(self): + def _get_children(self) -> list[str]: return self.client.retry(self.client.get_children, self.path) @@ -75,12 +95,14 @@ class Party(BaseParty): _NODE_NAME = "__party__" - def __init__(self, client, path, identifier=None): + def __init__( + self, client: KazooClient, path: str, identifier: str | None = None + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = uuid.uuid4().hex + self._NODE_NAME self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' data values""" self._ensure_parent() children = self._get_children() @@ -93,7 +115,7 @@ def __iter__(self): except NoNodeError: # pragma: nocover pass - def _get_children(self): + def _get_children(self) -> list[str]: children = BaseParty._get_children(self) return [c for c in children if self._NODE_NAME in c] @@ -109,12 +131,17 @@ class ShallowParty(BaseParty): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = "-".join([uuid.uuid4().hex, self.data.decode("utf-8")]) self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' identifiers""" self._ensure_parent() children = self._get_children() diff --git a/kazoo/recipe/queue.py b/kazoo/recipe/queue.py index 30d3066e..85a86676 100644 --- a/kazoo/recipe/queue.py +++ b/kazoo/recipe/queue.py @@ -9,17 +9,24 @@ See: https://github.com/python-zk/kazoo/issues/175 """ + +from __future__ import annotations + import uuid +from typing import TYPE_CHECKING from kazoo.exceptions import NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent from kazoo.retry import ForceRetryError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseQueue(object): """A common base class for queue implementations.""" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. @@ -27,10 +34,10 @@ def __init__(self, client, path): self.client = client self.path = path self._entries_path = path - self.structure_paths = (self.path,) + self.structure_paths: tuple[str, ...] = (self.path,) self.ensured_path = False - def _check_put_arguments(self, value, priority=100): + def _check_put_arguments(self, value: bytes, priority: int = 100) -> None: if not isinstance(value, bytes): raise TypeError("value must be a byte string") if not isinstance(priority, int): @@ -38,14 +45,14 @@ def _check_put_arguments(self, value, priority=100): elif priority < 0 or priority > 999: raise ValueError("priority must be between 0 and 999") - def _ensure_paths(self): + def _ensure_paths(self) -> None: if not self.ensured_path: # make sure our parent / internal structure nodes exists for path in self.structure_paths: self.client.ensure_path(path) self.ensured_path = True - def __len__(self): + def __len__(self) -> int: self._ensure_paths() _, stat = self.client.retry(self.client.get, self._entries_path) return stat.children_count @@ -62,19 +69,19 @@ class Queue(BaseQueue): prefix = "entry-" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(Queue, self).__init__(client, path) - self._children = [] + self._children: list[str] = [] - def __len__(self): + def __len__(self) -> int: """Return queue size.""" return super(Queue, self).__len__() - def get(self): + def get(self) -> bytes | None: """ Get item data and remove an item from the queue. @@ -84,7 +91,7 @@ def get(self): self._ensure_paths() return self.client.retry(self._inner_get) - def _inner_get(self): + def _inner_get(self) -> bytes | None: if not self._children: self._children = self.client.retry( self.client.get_children, self.path @@ -105,7 +112,7 @@ def _inner_get(self): self._children.pop(0) return data - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an item into the queue. :param value: Byte string to put into the queue. @@ -150,26 +157,26 @@ class LockingQueue(BaseQueue): entries = "/entries" entry = "entry" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(LockingQueue, self).__init__(client, path) self.id = uuid.uuid4().hex.encode() - self.processing_element = None + self.processing_element: tuple[str, bytes] | None = None self._lock_path = self.path + self.lock self._entries_path = self.path + self.entries self.structure_paths = (self._lock_path, self._entries_path) - def __len__(self): + def __len__(self) -> int: """Returns the current length of the queue. :returns: queue size (includes locked entries count). """ return super(LockingQueue, self).__len__() - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an entry into the queue. :param value: Byte string to put into the queue. @@ -189,7 +196,7 @@ def put(self, value, priority=100): sequence=True, ) - def put_all(self, values, priority=100): + def put_all(self, values: list[bytes], priority: int = 100) -> None: """Put several entries into the queue. The action only succeeds if all entries where put into the queue. @@ -221,7 +228,7 @@ def put_all(self, values, priority=100): sequence=True, ) - def get(self, timeout=None): + def get(self, timeout: float | None = None) -> bytes | None: """Locks and gets an entry from the queue. If a previously got entry was not consumed, this method will return that entry. @@ -237,7 +244,7 @@ def get(self, timeout=None): else: return self._inner_get(timeout) - def holds_lock(self): + def holds_lock(self) -> bool: """Checks if a node still holds the lock. :returns: True if a node still holds the lock, False otherwise. @@ -251,7 +258,7 @@ def holds_lock(self): value, stat = self.client.retry(self.client.get, lock_path) return value == self.id - def consume(self): + def consume(self) -> bool: """Removes a currently processing entry from the queue. :returns: True if element was removed successfully, False otherwise. @@ -271,7 +278,7 @@ def consume(self): else: return False - def release(self): + def release(self) -> bool: """Removes the lock from currently processed item without consuming it. :returns: True if the lock was removed successfully, False otherwise. @@ -289,13 +296,13 @@ def release(self): else: return False - def _inner_get(self, timeout): + def _inner_get(self, timeout: float | None) -> bytes | None: flag = self.client.handler.event_object() lock = self.client.handler.lock_object() canceled = False value = [] - def check_for_updates(event): + def check_for_updates(event: WatchedEvent | None) -> None: if event is not None and event.type != EventType.CHILD: return with lock: @@ -330,8 +337,8 @@ def check_for_updates(event): retVal = value[0][1] return retVal - def _filter_locked(self, values, taken): - taken = set(taken) + def _filter_locked(self, values: list[str], taken: list[str]) -> list[str]: + taken = set(taken) # type: ignore[assignment] available = sorted(values) return ( available @@ -339,7 +346,7 @@ def _filter_locked(self, values, taken): else [x for x in available if x not in taken] ) - def _take(self, id_): + def _take(self, id_: str) -> tuple[str, bytes] | None: try: self.client.create( "{path}/{id}".format(path=self._lock_path, id=id_), diff --git a/kazoo/recipe/watchers.py b/kazoo/recipe/watchers.py index d4cb0300..77f54547 100644 --- a/kazoo/recipe/watchers.py +++ b/kazoo/recipe/watchers.py @@ -10,25 +10,45 @@ will result in an exception being thrown. """ + +from __future__ import annotations + from functools import partial, wraps import logging import time import warnings +from typing import ( + Any, + List, + Callable, + Optional, + Union, + TYPE_CHECKING, + overload, +) +from typing_extensions import ParamSpec, deprecated from kazoo.exceptions import ConnectionClosedError, NoNodeError, KazooException -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat from kazoo.retry import KazooRetry +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import IAsyncResult log = logging.getLogger(__name__) _STOP_WATCHING = object() +GenericArgs = ParamSpec("GenericArgs") + -def _ignore_closed(func): +def _ignore_closed( + func: Callable[GenericArgs, None] +) -> Callable[GenericArgs, None]: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: GenericArgs.args, **kwargs: GenericArgs.kwargs) -> None: try: return func(*args, **kwargs) except ConnectionClosedError: @@ -37,6 +57,15 @@ def wrapper(*args, **kwargs): return wrapper +DataWatchFunc = Union[ + Callable[[Optional[bytes], Optional[ZnodeStat]], Optional[bool]], + Callable[ + [Optional[bytes], Optional[ZnodeStat], Optional[WatchedEvent]], + Optional[bool], + ], +] + + class DataWatch(object): """Watches a node for data updates and calls the specified function each time it changes @@ -88,7 +117,39 @@ def my_func(data, stat, event): """ - def __init__(self, client, path, func=None, *args, **kwargs): + @overload + def __init__( + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + ): + ... + + @overload + @deprecated( + "Passing additional arguments to DataWatch is deprecated. " + "ignore_missing_node is now assumed to be True by default, and the " + "event will be sent if the function can handle receiving it" + ) + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + *args: Any, + **kwargs: Any, + ): + ... + + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + *args: Any, + **kwargs: Any, + ): """Create a data watcher for a path :param client: A zookeeper client. @@ -107,7 +168,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._func = func self._stopped = False self._run_lock = client.handler.lock_object() - self._version = None + self._version: int | None = None self._retry = KazooRetry( max_tries=None, sleep_func=client.handler.sleep_func ) @@ -132,7 +193,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._client.add_listener(self._session_watcher) self._get_data() - def __call__(self, func): + def __call__(self, func: DataWatchFunc) -> DataWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -155,16 +216,28 @@ def __call__(self, func): self._get_data() return func - def _log_func_exception(self, data, stat, event=None): + def _log_func_exception( + self, + data: bytes | None, + stat: ZnodeStat | None, + event: WatchedEvent | None = None, + ) -> None: try: # For backwards compatibility, don't send event to the # callback unless the send_event is set in constructor if not self._ever_called: self._ever_called = True try: - result = self._func(data, stat, event) + # The type ignores here are because mypy can't figure out that + # 1) self._func can't ever be None (fingers crossed) + # 2) the function can be called with 2 arguments or with 3 + # arguments (though that could possibly be done with better + # typing) + result = self._func( # type: ignore[call-arg, misc] + data, stat, event + ) except TypeError: - result = self._func(data, stat) + result = self._func(data, stat) # type: ignore[call-arg, misc] if result is False: self._stopped = True self._func = None @@ -174,7 +247,7 @@ def _log_func_exception(self, data, stat, event=None): raise @_ignore_closed - def _get_data(self, event=None): + def _get_data(self, event: WatchedEvent | None = None) -> None: # Ensure this runs one at a time, possible because the session # watcher may trigger a run with self._run_lock: @@ -183,6 +256,7 @@ def _get_data(self, event=None): initial_version = self._version + stat: ZnodeStat | None try: data, stat = self._retry( self._client.get, self._path, self._watcher @@ -210,18 +284,24 @@ def _get_data(self, event=None): if initial_version != self._version or not self._ever_called: self._log_func_exception(data, stat, event) - def _watcher(self, event): + def _watcher(self, event: WatchedEvent) -> None: self._get_data(event=event) - def _set_watch(self, state): + def _set_watch(self, state: KazooState) -> None: with self._run_lock: self._watch_established = state - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.CONNECTED: self._client.handler.spawn(self._get_data) +ChildrenWatchFunc = Union[ + Callable[[List[str]], Optional[bool]], + Callable[[List[str], Optional[WatchedEvent]], Optional[bool]], +] + + class ChildrenWatch(object): """Watches a node for children updates and calls the specified function each time it changes @@ -253,11 +333,11 @@ def my_func(children): def __init__( self, - client, - path, - func=None, - allow_session_lost=True, - send_event=False, + client: KazooClient, + path: str, + func: ChildrenWatchFunc | None = None, + allow_session_lost: bool = True, + send_event: bool = False, ): """Create a children watcher for a path @@ -290,7 +370,7 @@ def __init__( self._watch_established = False self._allow_session_lost = allow_session_lost self._run_lock = client.handler.lock_object() - self._prior_children = None + self._prior_children: list[str] | None = None self._used = False # Register our session listener if we're going to resume @@ -301,7 +381,7 @@ def __init__( self._client.add_listener(self._session_watcher) self._get_children() - def __call__(self, func): + def __call__(self, func: ChildrenWatchFunc) -> ChildrenWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -325,7 +405,7 @@ def __call__(self, func): return func @_ignore_closed - def _get_children(self, event=None): + def _get_children(self, event: WatchedEvent | None = None) -> None: with self._run_lock: # Ensure this runs one at a time if self._stopped: return @@ -351,9 +431,16 @@ def _get_children(self, event=None): try: if self._send_event: - result = self._func(children, event) + # See comment about the type ignore here in DataWatch, + # it's the same issue where mypy can't figure out that the + # function can be called with 1 argument or with 2 + result = self._func( # type: ignore[misc] + children, event # type: ignore[call-arg] + ) else: - result = self._func(children) + result = self._func( # type: ignore[misc] + children # type: ignore[call-arg] + ) if result is False: self._stopped = True self._func = None @@ -363,11 +450,11 @@ def _get_children(self, event=None): log.exception(exc) raise - def _watcher(self, event): + def _watcher(self, event: WatchedEvent) -> None: if event.type != "NONE": self._get_children(event) - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state in (KazooState.LOST, KazooState.SUSPENDED): self._watch_established = False elif ( @@ -408,14 +495,16 @@ class PatientChildrenWatch(object): """ - def __init__(self, client, path, time_boundary=30): + def __init__( + self, client: KazooClient, path: str, time_boundary: float = 30 + ): self.client = client self.path = path - self.children = [] + self.children: list[str] = [] self.time_boundary = time_boundary self.children_changed = client.handler.event_object() - def start(self): + def start(self) -> IAsyncResult: """Begin the watching process asynchronously :returns: An :class:`~kazoo.interfaces.IAsyncResult` instance @@ -427,7 +516,7 @@ def start(self): self.client.handler.spawn(self._inner_start) return asy - def _inner_start(self): + def _inner_start(self) -> None: try: while True: async_result = self.client.handler.async_result() @@ -447,6 +536,8 @@ def _inner_start(self): except Exception as exc: self.asy.set_exception(exc) - def _children_watcher(self, async_result, event): + def _children_watcher( + self, async_result: IAsyncResult, event: WatchedEvent + ) -> None: self.children_changed.set() async_result.set(time.monotonic()) diff --git a/kazoo/retry.py b/kazoo/retry.py index fb9e8fc7..9e4e0c9a 100644 --- a/kazoo/retry.py +++ b/kazoo/retry.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import random import time +from typing import Any, Callable, TypeVar from kazoo.exceptions import ( ConnectionClosedError, @@ -10,7 +13,6 @@ SessionExpiredError, ) - log = logging.getLogger(__name__) @@ -37,18 +39,20 @@ class KazooRetry(object): EXPIRED_EXCEPTIONS = (SessionExpiredError,) + RETRY_RETURN = TypeVar("RETRY_RETURN") + def __init__( self, - max_tries=1, - delay=0.1, - backoff=2, - max_jitter=0.4, - max_delay=60.0, - ignore_expire=True, - sleep_func=time.sleep, - deadline=None, - interrupt=None, - ): + max_tries: int | None = 1, + delay: float = 0.1, + backoff: int = 2, + max_jitter: float = 0.4, + max_delay: float = 60.0, + ignore_expire: bool = True, + sleep_func: Callable[[float], None] = time.sleep, + deadline: float | None = None, + interrupt: Callable[[], bool] | None = None, + ) -> None: """Create a :class:`KazooRetry` instance for retrying function calls. @@ -81,20 +85,22 @@ def __init__( self._attempts = 0 self._cur_delay = delay self.deadline = deadline - self._cur_stoptime = None + self._cur_stoptime: float | None = None self.sleep_func = sleep_func - self.retry_exceptions = self.RETRY_EXCEPTIONS + self.retry_exceptions: tuple[ + type[Exception], ... + ] = self.RETRY_EXCEPTIONS self.interrupt = interrupt if ignore_expire: self.retry_exceptions += self.EXPIRED_EXCEPTIONS - def reset(self): + def reset(self) -> None: """Reset the attempt counter""" self._attempts = 0 self._cur_delay = self.delay self._cur_stoptime = None - def copy(self): + def copy(self) -> KazooRetry: """Return a clone of this retry manager""" obj = KazooRetry( max_tries=self.max_tries, @@ -109,7 +115,12 @@ def copy(self): obj.retry_exceptions = self.retry_exceptions return obj - def __call__(self, func, *args, **kwargs): + def __call__( + self, + func: Callable[..., RETRY_RETURN], + *args: Any, + **kwargs: Any, + ) -> RETRY_RETURN: """Call a function with arguments until it completes without throwing a Kazoo exception diff --git a/kazoo/security.py b/kazoo/security.py index 68399445..1b383795 100644 --- a/kazoo/security.py +++ b/kazoo/security.py @@ -1,14 +1,19 @@ """Kazoo Security""" + +from __future__ import annotations + from base64 import b64encode -from collections import namedtuple import hashlib +from typing import NamedTuple # Represents a Zookeeper ID and ACL object -Id = namedtuple("Id", "scheme id") +class Id(NamedTuple): + scheme: str + id: str -class ACL(namedtuple("ACL", "perms id")): +class ACL(NamedTuple): """An ACL for a Zookeeper Node An ACL object is created by using an :class:`Id` object along with @@ -17,8 +22,11 @@ class ACL(namedtuple("ACL", "perms id")): the desired scheme, id, and permissions. """ + perms: int + id: Id + @property - def acl_list(self): + def acl_list(self) -> list[str]: perms = [] if self.perms & Permissions.ALL == Permissions.ALL: perms.append("ALL") @@ -35,7 +43,7 @@ def acl_list(self): perms.append("ADMIN") return perms - def __repr__(self): + def __repr__(self) -> str: return "ACL(perms=%r, acl_list=%s, id=%r)" % ( self.perms, self.acl_list, @@ -62,7 +70,7 @@ class Permissions(object): READ_ACL_UNSAFE = [ACL(Permissions.READ, ANYONE_ID_UNSAFE)] -def make_digest_acl_credential(username, password): +def make_digest_acl_credential(username: str, password: str) -> str: """Create a SHA1 digest credential. .. note:: @@ -80,15 +88,15 @@ def make_digest_acl_credential(username, password): def make_acl( - scheme, - credential, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + scheme: str, + credential: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Given a scheme and credential, return an :class:`ACL` object appropriate for use with Kazoo. @@ -131,15 +139,15 @@ def make_acl( def make_digest_acl( - username, - password, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + username: str, + password: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Create a digest ACL for Zookeeper with the given permissions This method combines :meth:`make_digest_acl_credential` and diff --git a/kazoo/testing/common.py b/kazoo/testing/common.py index 4f702e42..6f97c547 100644 --- a/kazoo/testing/common.py +++ b/kazoo/testing/common.py @@ -18,7 +18,7 @@ # # You should have received a copy of the GNU Lesser General Public License # along with txzookeeper. If not, see . - +from __future__ import annotations import code from collections import namedtuple @@ -37,16 +37,19 @@ import OpenSSL import jks +from types import FrameType +from typing import Any, Iterator log = logging.getLogger(__name__) -def debug(sig, frame): +def debug(sig: int, frame: FrameType | None) -> None: """Interrupt running process, and provide a python prompt for interactive debugging.""" d = {"_frame": frame} # Allow access to frame object. - d.update(frame.f_globals) # Unless shadowed by global - d.update(frame.f_locals) + if frame is not None: + d.update(frame.f_globals) # Unless shadowed by global + d.update(frame.f_locals) i = code.InteractiveConsole(d) message = "Signal received : entering python shell.\nTraceback:\n" @@ -54,7 +57,7 @@ def debug(sig, frame): i.interact(message) -def listen(): +def listen() -> None: if os.name != "nt": # SIGUSR1 is not supported on Windows signal.signal(signal.SIGUSR1, debug) # Register handler @@ -62,7 +65,7 @@ def listen(): listen() -def to_java_compatible_path(path): +def to_java_compatible_path(path: str) -> str: if os.name == "nt": path = path.replace("\\", "/") return path @@ -85,14 +88,14 @@ class ManagedZooKeeper(object): def __init__( self, - software_path, - server_info, - peers=(), - classpath=None, - configuration_entries=(), - java_system_properties=(), - jaas_config=None, - ssl_configuration=None, + software_path: str, + server_info: ServerInfo, + peers: list[ServerInfo], + classpath: str, + configuration_entries: list[str], + java_system_properties: list[str], + jaas_config: str | None = None, + ssl_configuration: dict[str, Any] | None = None, ): """Define the ZooKeeper test instance. @@ -104,8 +107,8 @@ def __init__( self.server_info = server_info self.host = "127.0.0.1" self.peers = peers - self.working_path = tempfile.mkdtemp() - self._running = False + self.working_path: str = tempfile.mkdtemp() + self._running: bool = False self.configuration_entries = configuration_entries self.java_system_properties = java_system_properties self.jaas_config = jaas_config @@ -113,7 +116,7 @@ def __init__( ssl_configuration if ssl_configuration is not None else {} ) - def run(self): + def run(self) -> None: """Run the ZooKeeper instance under a temporary directory. Writes ZK log messages to zookeeper.log in the current directory. @@ -257,7 +260,7 @@ def run(self): self._running = True @property - def classpath(self): + def classpath(self) -> str: """Get the classpath necessary to run ZooKeeper.""" if self._classpath: @@ -290,28 +293,28 @@ def classpath(self): return os.pathsep.join(jars) @property - def address(self): + def address(self) -> str: """Get the address of the ZooKeeper instance.""" return "%s:%s" % (self.host, self.client_port) @property - def secure_address(self): + def secure_address(self) -> str: """Get the address of the SSL ZooKeeper instance.""" return "%s:%s" % (self.host, self.secure_client_port) @property - def running(self): + def running(self) -> bool: return self._running @property - def client_port(self): + def client_port(self) -> Any: return self.server_info.client_port @property - def secure_client_port(self): + def secure_client_port(self) -> Any: return self.server_info.secure_client_port - def reset(self): + def reset(self) -> None: """Stop the zookeeper instance, cleaning out its on disk-data.""" self.stop() shutil.rmtree(os.path.join(self.working_path, "data"), True) @@ -319,7 +322,7 @@ def reset(self): with open(os.path.join(self.working_path, "data", "myid"), "w") as fh: fh.write(str(self.server_info.server_id)) - def stop(self): + def stop(self) -> None: """Stop the Zookeeper instance, retaining on disk state.""" if not self.running: return @@ -335,14 +338,14 @@ def stop(self): ) self._running = False - def destroy(self): + def destroy(self) -> None: """Stop the ZooKeeper instance and destroy its on disk-state""" # called by at exit handler, reimport to avoid cleanup race. self.stop() shutil.rmtree(self.working_path, True) - def get_logs(self, num_lines=100): + def get_logs(self, num_lines: int = 100) -> list[str]: log_path = pathlib.Path(self.working_path, "zookeeper.log") if log_path.exists(): log_file = log_path.open("r") @@ -354,24 +357,24 @@ def get_logs(self, num_lines=100): class ZookeeperCluster(object): def __init__( self, - install_path=None, - classpath=None, - size=3, - port_offset=20000, - observer_start_id=-1, - configuration_entries=(), - java_system_properties=(), - jaas_config=None, + install_path: str, + classpath: str, + size: int, + port_offset: int, + observer_start_id: int, + configuration_entries: list[str], + java_system_properties: list[str], + jaas_config: str | None, ): self._install_path = install_path self._classpath = classpath self._servers = [] - self._ssl_configuration = {} + self._ssl_configuration: dict[str, Any] = {} self.perform_ssl_certs_generation() # Calculate ports and peer group port = port_offset - peers = [] + peers: list[ServerInfo] = [] for i in range(size): server_id = i + 1 @@ -408,13 +411,13 @@ def __init__( ) ) - def __getitem__(self, k): + def __getitem__(self, k: int) -> ManagedZooKeeper: return self._servers[k] - def __iter__(self): + def __iter__(self) -> Iterator[ManagedZooKeeper]: return iter(self._servers) - def start(self): + def start(self) -> None: # Zookeeper client expresses a preference for either lower ports or # lexicographical ordering of hosts, to ensure that all servers have a # chance to startup, start them in reverse order. @@ -427,26 +430,26 @@ def start(self): time.sleep(2) - def stop(self): + def stop(self) -> None: for server in self: server.stop() self._servers = [] - def terminate(self): + def terminate(self) -> None: for server in self: server.destroy() - def reset(self): + def reset(self) -> None: for server in self: server.reset() - def get_logs(self): + def get_logs(self) -> list[str]: logs = [] for server in self: logs += server.get_logs() return logs - def perform_ssl_certs_generation(self): + def perform_ssl_certs_generation(self) -> None: if self._ssl_configuration: return @@ -542,7 +545,7 @@ def perform_ssl_certs_generation(self): "keystore": keystore, } - def get_ssl_client_configuration(self): + def get_ssl_client_configuration(self) -> dict[str, Any]: if not self._ssl_configuration: raise RuntimeError("SSL not configured yet.") return { diff --git a/kazoo/testing/harness.py b/kazoo/testing/harness.py index bb77f071..10542c99 100644 --- a/kazoo/testing/harness.py +++ b/kazoo/testing/harness.py @@ -1,10 +1,17 @@ """Kazoo testing harnesses""" +from __future__ import annotations + import atexit import logging import os import uuid import unittest +from typing import Any, Callable, Literal, cast, TYPE_CHECKING + +if TYPE_CHECKING: + import kazoo.interfaces + from kazoo.client import KazooClient from kazoo.exceptions import KazooException from kazoo.protocol.connection import _CONNECTION_DROP, _SESSION_EXPIRED @@ -13,8 +20,8 @@ log = logging.getLogger(__name__) -CLUSTER = None -CLUSTER_CONF = None +CLUSTER: ZookeeperCluster | None = None +CLUSTER_CONF: dict[str, Any] | None = None CLUSTER_DEFAULTS = { "ZOOKEEPER_PORT_OFFSET": 20000, "ZOOKEEPER_CLUSTER_SIZE": 3, @@ -24,7 +31,8 @@ MAX_INIT_TRIES = 5 -def get_global_cluster(): +# FIXME use a typeddict for cluster conf and cluster defaults +def get_global_cluster() -> ZookeeperCluster: global CLUSTER, CLUSTER_CONF cluster_conf = { k: os.environ.get(k, CLUSTER_DEFAULTS.get(k)) @@ -47,16 +55,22 @@ def get_global_cluster(): CLUSTER.terminate() CLUSTER = None # Create a new cluster - ZK_HOME = cluster_conf.get("ZOOKEEPER_PATH") - ZK_CLASSPATH = cluster_conf.get("ZOOKEEPER_CLASSPATH") - ZK_PORT_OFFSET = int(cluster_conf.get("ZOOKEEPER_PORT_OFFSET")) - ZK_CLUSTER_SIZE = int(cluster_conf.get("ZOOKEEPER_CLUSTER_SIZE")) - ZK_VERSION = cluster_conf.get("ZOOKEEPER_VERSION") - if "-" in ZK_VERSION: + ZK_HOME = cast(str, cluster_conf.get("ZOOKEEPER_PATH")) + ZK_CLASSPATH = cast(str, cluster_conf.get("ZOOKEEPER_CLASSPATH")) + ZK_PORT_OFFSET = int( # type: ignore[call-overload] + cluster_conf.get("ZOOKEEPER_PORT_OFFSET") + ) + ZK_CLUSTER_SIZE = int( # type: ignore[call-overload] + cluster_conf.get("ZOOKEEPER_CLUSTER_SIZE") + ) + ZK_VERSION_STR = cast(str, cluster_conf.get("ZOOKEEPER_VERSION")) + if "-" in ZK_VERSION_STR: # Ignore pre-release markers like -alpha - ZK_VERSION = ZK_VERSION.split("-")[0] - ZK_VERSION = tuple([int(n) for n in ZK_VERSION.split(".")]) - ZK_OBSERVER_START_ID = int(cluster_conf.get("ZOOKEEPER_OBSERVER_START_ID")) + ZK_VERSION_STR = ZK_VERSION_STR.split("-")[0] + ZK_VERSION = tuple([int(n) for n in ZK_VERSION_STR.split(".")]) + ZK_OBSERVER_START_ID = int( # type: ignore[call-overload] + cluster_conf.get("ZOOKEEPER_OBSERVER_START_ID") + ) assert ZK_HOME or ZK_CLASSPATH or ZK_VERSION, ( "Either ZOOKEEPER_PATH or ZOOKEEPER_CLASSPATH or " @@ -65,8 +79,8 @@ def get_global_cluster(): ) if ZK_VERSION >= (3, 5): - ZOOKEEPER_LOCAL_SESSION_RO = cluster_conf.get( - "ZOOKEEPER_LOCAL_SESSION_RO" + ZOOKEEPER_LOCAL_SESSION_RO = cast( + str, cluster_conf.get("ZOOKEEPER_LOCAL_SESSION_RO") ) additional_configuration_entries = [ "4lw.commands.whitelist=*", @@ -140,69 +154,78 @@ class KazooTestHarness(unittest.TestCase): Example:: class MyTestCase(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: self.setup_zookeeper() # additional test setup - def tearDown(self): + def tearDown(self)-> None: self.teardown_zookeeper() - def test_something(self): + def test_something(self) -> None: something_that_needs_a_kazoo_client(self.client) - def test_something_else(self): + def test_something_else(self) -> None: something_that_needs_zk_servers(self.servers) """ DEFAULT_CLIENT_TIMEOUT = 15 - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any): super(KazooTestHarness, self).__init__(*args, **kw) - self.client = None - self._clients = [] + self._client: KazooClient | None = None + self._clients: list[KazooClient] = [] @property - def cluster(self): + def cluster(self) -> ZookeeperCluster: return get_global_cluster() - def log(self, level, msg, *args, **kwargs): + @property + def client(self) -> KazooClient: + assert self._client is not None + return self._client + + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: log.log(level, msg, *args, **kwargs) @property - def servers(self): + def servers(self) -> str: return ",".join([s.address for s in self.cluster]) @property - def secure_servers(self): + def secure_servers(self) -> str: return ",".join([s.secure_address for s in self.cluster]) - def _get_nonchroot_client(self): + def _get_nonchroot_client(self) -> KazooClient: c = KazooClient(self.servers) self._clients.append(c) return c - def _get_client(self, **client_options): + def _get_client(self, **client_options: Any) -> KazooClient: if "timeout" not in client_options: client_options["timeout"] = self.DEFAULT_CLIENT_TIMEOUT c = KazooClient(self.hosts, **client_options) self._clients.append(c) return c - def lose_connection(self, event_factory): + def lose_connection( + self, event_factory: Callable[[], kazoo.interfaces.Event] + ) -> None: """Force client to lose connection with server""" self.__break_connection( _CONNECTION_DROP, KazooState.SUSPENDED, event_factory ) - def expire_session(self, event_factory): + def expire_session( + self, event_factory: Callable[[], kazoo.interfaces.Event] + ) -> None: """Force ZK to expire a client session""" self.__break_connection( _SESSION_EXPIRED, KazooState.LOST, event_factory ) - def setup_zookeeper(self, **client_options): + def setup_zookeeper(self, **client_options: Any) -> None: """Create a ZK cluster and chrooted :class:`KazooClient` The cluster will only be created on the first invocation and won't be @@ -242,11 +265,11 @@ def setup_zookeeper(self, **client_options): self.hosts = self.secure_servers + namespace else: self.hosts = self.servers + namespace - self.client = self._get_client(**client_options) + self._client = self._get_client(**client_options) self.client.start() self.client.ensure_path("/") - def teardown_zookeeper(self): + def teardown_zookeeper(self) -> None: """Reset and cleanup the zookeeper cluster that was started.""" while self._clients: c = self._clients.pop() @@ -256,23 +279,29 @@ def teardown_zookeeper(self): log.exception("Failed stopping client %s", c) finally: c.close() - self.client = None - - def __break_connection(self, break_event, expected_state, event_factory): + self._client = None + + def __break_connection( + self, + break_event: object, + expected_state: KazooState, + event_factory: Callable[[], kazoo.interfaces.Event], + ) -> None: """Break ZooKeeper connection using the specified event.""" lost = event_factory() safe = event_factory() - def watch_loss(state): + def watch_loss(state: KazooState) -> Literal[True] | None: if state == expected_state: lost.set() elif lost.is_set() and state == KazooState.CONNECTED: safe.set() return True + return None self.client.add_listener(watch_loss) - self.client._call(break_event, None) + self.client._call(break_event, None) # type: ignore[arg-type] lost.wait(5) if not lost.is_set(): @@ -286,14 +315,14 @@ def watch_loss(state): class KazooTestCase(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: self.setup_zookeeper() - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: cluster = get_global_cluster() if cluster is not None: cluster.terminate() diff --git a/kazoo/tests/conftest.py b/kazoo/tests/conftest.py index 931fd84f..94da635a 100644 --- a/kazoo/tests/conftest.py +++ b/kazoo/tests/conftest.py @@ -1,9 +1,11 @@ import logging +from typing import Any + log = logging.getLogger(__name__) -def pytest_exception_interact(node, call, report): +def pytest_exception_interact(node: Any, call: Any, report: Any) -> None: try: cluster = node._testcase.cluster log.error("Zookeeper cluster logs:") diff --git a/kazoo/tests/test_barrier.py b/kazoo/tests/test_barrier.py index 5f79e861..f77c3169 100644 --- a/kazoo/tests/test_barrier.py +++ b/kazoo/tests/test_barrier.py @@ -1,35 +1,37 @@ +from __future__ import annotations + import threading from kazoo.testing import KazooTestCase class KazooBarrierTests(KazooTestCase): - def test_barrier_not_exist(self): + def test_barrier_not_exist(self) -> None: b = self.client.Barrier("/some/path") assert b.wait() - def test_barrier_exists(self): + def test_barrier_exists(self) -> None: b = self.client.Barrier("/some/path") b.create() assert not b.wait(0) b.remove() assert b.wait() - def test_remove_nonexistent_barrier(self): + def test_remove_nonexistent_barrier(self) -> None: b = self.client.Barrier("/some/path") assert not b.remove() class KazooDoubleBarrierTests(KazooTestCase): - def test_basic_barrier(self): + def test_basic_barrier(self) -> None: b = self.client.DoubleBarrier("/some/path", 1) assert not b.participating b.enter() assert b.participating - b.leave() + b.leave() # type: ignore[unreachable] assert not b.participating - def test_two_barrier(self): + def test_two_barrier(self) -> None: av = threading.Event() ev = threading.Event() bv = threading.Event() @@ -37,14 +39,14 @@ def test_two_barrier(self): b1 = self.client.DoubleBarrier("/some/path", 2) b2 = self.client.DoubleBarrier("/some/path", 2) - def make_barrier_one(): + def make_barrier_one() -> None: b1.enter() ev.set() release_all.wait() b1.leave() ev.set() - def make_barrier_two(): + def make_barrier_two() -> None: bv.wait() b2.enter() av.set() @@ -65,7 +67,7 @@ def make_barrier_two(): av.wait() ev.wait() assert b1.participating - assert b2.participating + assert b2.participating # type: ignore[unreachable] av.clear() ev.clear() @@ -78,7 +80,7 @@ def make_barrier_two(): t1.join() t2.join() - def test_three_barrier(self): + def test_three_barrier(self) -> None: av = threading.Event() ev = threading.Event() bv = threading.Event() @@ -87,14 +89,14 @@ def test_three_barrier(self): b2 = self.client.DoubleBarrier("/some/path", 3) b3 = self.client.DoubleBarrier("/some/path", 3) - def make_barrier_one(): + def make_barrier_one() -> None: b1.enter() ev.set() release_all.wait() b1.leave() ev.set() - def make_barrier_two(): + def make_barrier_two() -> None: bv.wait() b2.enter() av.set() @@ -119,7 +121,7 @@ def make_barrier_two(): av.wait() assert b1.participating - assert b2.participating + assert b2.participating # type: ignore[unreachable] assert b3.participating av.clear() @@ -135,7 +137,7 @@ def make_barrier_two(): t1.join() t2.join() - def test_barrier_existing_parent_node(self): + def test_barrier_existing_parent_node(self) -> None: b = self.client.DoubleBarrier("/some/path", 1) assert b.participating is False self.client.create("/some", ephemeral=True) @@ -143,7 +145,7 @@ def test_barrier_existing_parent_node(self): b.enter() assert b.participating is False - def test_barrier_existing_node(self): + def test_barrier_existing_node(self) -> None: b = self.client.DoubleBarrier("/some", 1) assert b.participating is False self.client.ensure_path(b.path) @@ -151,4 +153,4 @@ def test_barrier_existing_node(self): # the barrier will re-use an existing node b.enter() assert b.participating is True - b.leave() + b.leave() # type: ignore[unreachable] diff --git a/kazoo/tests/test_build.py b/kazoo/tests/test_build.py index 01dbf873..7b36ddf1 100644 --- a/kazoo/tests/test_build.py +++ b/kazoo/tests/test_build.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pytest @@ -6,14 +8,14 @@ class TestBuildEnvironment(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) if not os.environ.get("CI"): pytest.skip("Only run build config tests on CI.") - def test_zookeeper_version(self): - server_version = self.client.server_version() - server_version = ".".join([str(i) for i in server_version]) + def test_zookeeper_version(self) -> None: + server_version1 = self.client.server_version() + server_version = ".".join([str(i) for i in server_version1]) env_version = os.environ.get("ZOOKEEPER_VERSION") if env_version: if "-" in env_version: diff --git a/kazoo/tests/test_cache.py b/kazoo/tests/test_cache.py index 7251db64..a5bf2d78 100644 --- a/kazoo/tests/test_cache.py +++ b/kazoo/tests/test_cache.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import gc import importlib import sys import uuid +from typing import Any, TYPE_CHECKING +from queue import Queue from unittest.mock import patch, call, Mock import pytest @@ -11,6 +15,11 @@ from kazoo.exceptions import KazooException from kazoo.recipe.cache import TreeCache, TreeNode, TreeEvent +if TYPE_CHECKING: + from kazoo.handlers.gevent import SequentialGeventHandler + from kazoo.handlers.eventlet import SequentialEventletHandler + from kazoo.handlers.threading import SequentialThreadingHandler + class KazooAdaptiveHandlerTestCase(KazooTestHarness): HANDLERS = ( @@ -19,15 +28,22 @@ class KazooAdaptiveHandlerTestCase(KazooTestHarness): ("kazoo.handlers.threading", "SequentialThreadingHandler"), ) - def setUp(self): + def setUp(self) -> None: self.handler = self.choose_an_installed_handler() self.setup_zookeeper(handler=self.handler) - def tearDown(self): + def tearDown(self) -> None: self.handler = None self.teardown_zookeeper() - def choose_an_installed_handler(self): + def choose_an_installed_handler( + self, + ) -> ( + SequentialGeventHandler + | SequentialEventletHandler + | SequentialThreadingHandler + | None + ): for handler_module, handler_class in self.HANDLERS: if ( handler_module == "kazoo.handlers.gevent" @@ -40,39 +56,60 @@ def choose_an_installed_handler(self): except ImportError: continue else: - return cls() + # FIXME Should be no-any-return but hound is a dog + return cls() # type: ignore raise ImportError("No available handler") class KazooTreeCacheTests(KazooAdaptiveHandlerTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooTreeCacheTests, self).setUp() - self._event_queue = self.client.handler.queue_impl() + self._event_queue: Queue[TreeEvent] = self.client.handler.queue_impl() self._error_queue = self.client.handler.queue_impl() - self.path = None - self.cache = None + self._path: str | None = None + self._cache: TreeCache | None = None - def tearDown(self): + def tearDown(self) -> None: if not self._error_queue.empty(): try: raise self._error_queue.get() except FakeException: pass - if self.cache is not None: - self.cache.close() - self.cache = None + if self._cache is not None: + self._cache.close() + self._cache = None super(KazooTreeCacheTests, self).tearDown() - def make_cache(self): - if self.cache is None: - self.path = "/" + uuid.uuid4().hex - self.cache = TreeCache(self.client, self.path) - self.cache.listen(lambda event: self._event_queue.put(event)) - self.cache.listen_fault(lambda error: self._error_queue.put(error)) - self.cache.start() - return self.cache - - def wait_cache(self, expect=None, since=None, timeout=10): + def make_cache(self) -> TreeCache: + if self._cache is None: + self._path = "/" + uuid.uuid4().hex + self._cache = TreeCache(self.client, self.path) + self._cache.listen(lambda event: self._event_queue.put(event)) + self._cache.listen_fault( + lambda error: self._error_queue.put(error) + ) + self._cache.start() + return self._cache + + # FIXME This is entirely for the purpose of minimising code changes. + # Calling make_cache twice should be an error and the return value + # should be used, not stored. + @property + def cache(self) -> TreeCache: + assert self._cache is not None + return self._cache + + @property + def path(self) -> str: + assert self._path is not None + return self._path + + def wait_cache( + self, + expect: int | None = None, + since: int | None = None, + timeout: float = 10, + ) -> TreeEvent | None: started = since is None while True: event = self._event_queue.get(timeout=timeout) @@ -83,25 +120,25 @@ def wait_cache(self, expect=None, since=None, timeout=10): if event.event_type == since: started = True if expect is None: - return + return None - def spy_client(self, method_name): + def spy_client(self, method_name: str) -> Any: method = getattr(self.client, method_name) return patch.object(self.client, method_name, wraps=method) - def _wait_gc(self): + def _wait_gc(self) -> None: # trigger switching on some coroutine handlers self.client.handler.sleep_func(0.1) completion_queue = getattr(self.handler, "completion_queue", None) if completion_queue is not None: - while not self.client.handler.completion_queue.empty(): + while not completion_queue.empty(): self.client.handler.sleep_func(0.1) for gen in range(3): gc.collect(gen) - def count_tree_node(self): + def count_tree_node(self) -> int: # inspect GC and count tree nodes for checking memory leak for retry in range(10): result = set() @@ -112,28 +149,29 @@ def count_tree_node(self): return list(result)[0] raise RuntimeError("could not count refs exactly") - def test_start(self): + def test_start(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) stat = self.client.exists(self.path) + assert stat is not None assert stat.version == 0 assert self.cache._state == TreeCache.STATE_STARTED assert self.cache._root._state == TreeNode.STATE_LIVE - def test_start_started(self): + def test_start_started(self) -> None: self.make_cache() with pytest.raises(KazooException): self.cache.start() - def test_start_closed(self): + def test_start_closed(self) -> None: self.make_cache() self.cache.close() with pytest.raises(KazooException): self.cache.start() - def test_close(self): + def test_close(self) -> None: assert self.count_tree_node() == 0 self.make_cache() @@ -192,11 +230,13 @@ def test_close(self): == stub_child_watcher ) + # FIXME This looks pointless at best. + self._cache = None + # should not be any leaked memory (tree node) here - self.cache = None assert self.count_tree_node() == 0 - def test_delete_operation(self): + def test_delete_operation(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) @@ -225,12 +265,13 @@ def test_delete_operation(self): # should not be any leaked memory (tree node) here assert self.count_tree_node() == 1 - def test_children_operation(self): + def test_children_operation(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/test_children", b"test_children_1") event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_type == TreeEvent.NODE_ADDED assert event.event_data.path == self.path + "/test_children" assert event.event_data.data == b"test_children_1" @@ -238,6 +279,7 @@ def test_children_operation(self): self.client.set(self.path + "/test_children", b"test_children_2") event = self.wait_cache(TreeEvent.NODE_UPDATED) + assert event is not None assert event.event_type == TreeEvent.NODE_UPDATED assert event.event_data.path == self.path + "/test_children" assert event.event_data.data == b"test_children_2" @@ -245,18 +287,20 @@ def test_children_operation(self): self.client.delete(self.path + "/test_children") event = self.wait_cache(TreeEvent.NODE_REMOVED) + assert event is not None assert event.event_type == TreeEvent.NODE_REMOVED assert event.event_data.path == self.path + "/test_children" assert event.event_data.data == b"test_children_2" assert event.event_data.stat.version == 1 - def test_subtree_operation(self): + def test_subtree_operation(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo/bar/baz", makepath=True) for relative_path in ("/foo", "/foo/bar", "/foo/bar/baz"): event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_type == TreeEvent.NODE_ADDED assert event.event_data.path == self.path + relative_path assert event.event_data.data == b"" @@ -265,10 +309,11 @@ def test_subtree_operation(self): self.client.delete(self.path + "/foo", recursive=True) for relative_path in ("/foo/bar/baz", "/foo/bar", "/foo"): event = self.wait_cache(TreeEvent.NODE_REMOVED) + assert event is not None assert event.event_type == TreeEvent.NODE_REMOVED assert event.event_data.path == self.path + relative_path - def test_get_data(self): + def test_get_data(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo/bar/baz", b"@", makepath=True) @@ -277,19 +322,27 @@ def test_get_data(self): self.wait_cache(TreeEvent.NODE_ADDED) with patch.object(cache, "_client"): # disable any remote operation - assert cache.get_data(self.path).data == b"" - assert cache.get_data(self.path).stat.version == 0 - - assert cache.get_data(self.path + "/foo").data == b"" - assert cache.get_data(self.path + "/foo").stat.version == 0 - - assert cache.get_data(self.path + "/foo/bar").data == b"" - assert cache.get_data(self.path + "/foo/bar").stat.version == 0 - - assert cache.get_data(self.path + "/foo/bar/baz").data == b"@" - assert cache.get_data(self.path + "/foo/bar/baz").stat.version == 0 - - def test_get_children(self): + node = cache.get_data(self.path) + assert node is not None + assert node.data == b"" + assert node.stat.version == 0 + + node = cache.get_data(self.path + "foo") + assert node is not None + assert node.data == b"" + assert node.stat.version == 0 + + node = cache.get_data(self.path + "foo/bar") + assert node is not None + assert node.data == b"" + assert node.stat.version == 0 + + node = cache.get_data(self.path + "foo/bar/baz") + assert node is not None + assert node.data == b"@" + assert node.stat.version == 0 + + def test_get_children(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo/bar/baz", b"@", makepath=True) @@ -307,38 +360,39 @@ def test_get_children(self): assert cache.get_children(self.path + "/foo") == frozenset(["bar"]) assert cache.get_children(self.path) == frozenset(["foo"]) - def test_get_data_out_of_tree(self): + def test_get_data_out_of_tree(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with pytest.raises(ValueError): self.cache.get_data("/out_of_tree") - def test_get_children_out_of_tree(self): + def test_get_children_out_of_tree(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with pytest.raises(ValueError): self.cache.get_children("/out_of_tree") - def test_get_data_no_node(self): + def test_get_data_no_node(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with patch.object(cache, "_client"): # disable any remote operation assert cache.get_data(self.path + "/non_exists") is None - def test_get_children_no_node(self): + def test_get_children_no_node(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with patch.object(cache, "_client"): # disable any remote operation assert cache.get_children(self.path + "/non_exists") is None - def test_session_reconnected(self): + def test_session_reconnected(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo") event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_data.path == self.path + "/foo" with self.spy_client("get_async") as get_data: @@ -382,13 +436,14 @@ def test_session_reconnected(self): any_order=True, ) - def test_root_recreated(self): + def test_root_recreated(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) # remove root node self.client.delete(self.path) event = self.wait_cache(TreeEvent.NODE_REMOVED) + assert event is not None assert event.event_type == TreeEvent.NODE_REMOVED assert event.event_data.data == b"" assert event.event_data.path == self.path @@ -397,6 +452,7 @@ def test_root_recreated(self): # re-create root node self.client.ensure_path(self.path) event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_type == TreeEvent.NODE_ADDED assert event.event_data.data == b"" assert event.event_data.path == self.path @@ -406,7 +462,7 @@ def test_root_recreated(self): "unexpected outstanding ops %r" % self.cache._outstanding_ops ) - def test_exception_handler(self): + def test_exception_handler(self) -> None: error_value = FakeException() error_handler = Mock() @@ -419,7 +475,7 @@ def test_exception_handler(self): self.cache.close() error_handler.assert_called_once_with(error_value) - def test_exception_suppressed(self): + def test_exception_suppressed(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index 3f1748c4..bc7a78e0 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import socket import tempfile @@ -5,11 +7,14 @@ import time import uuid import unittest + +from typing import Any from unittest.mock import Mock, MagicMock, patch import pytest -from kazoo.testing import KazooTestCase + +from kazoo.client import KazooClient from kazoo.exceptions import ( AuthFailedError, BadArgumentsError, @@ -19,27 +24,45 @@ ConnectionLoss, InvalidACLError, NoAuthError, + NoChildrenForEphemeralsError, NoNodeError, NodeExistsError, + RolledBackError, SessionExpiredError, KazooException, ) +from kazoo.interfaces import IAsyncResult +from kazoo.protocol.states import KazooState, KeeperState, WatchedEvent +from kazoo.handlers.threading import ( + SequentialThreadingHandler, + KazooTimeoutError, +) from kazoo.protocol.connection import _CONNECTION_DROP -from kazoo.protocol.states import KeeperState, KazooState +from kazoo.retry import KazooRetry +from kazoo.security import ( + make_digest_acl_credential, + CREATOR_ALL_ACL, + make_digest_acl, + ACL, + OPEN_ACL_UNSAFE, +) + +from kazoo.testing import KazooTestCase +from kazoo.testing.common import ManagedZooKeeper from kazoo.tests.util import CI_ZK_VERSION class TestClientTransitions(KazooTestCase): @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def test_connection_and_disconnection(self): + def test_connection_and_disconnection(self) -> None: states = [] rc = threading.Event() @self.client.add_listener - def listener(state): + def listener(state: KazooState) -> None: states.append(state) if state == KazooState.CONNECTED: rc.set() @@ -61,18 +84,14 @@ def listener(state): class TestClientConstructor(unittest.TestCase): - def _makeOne(self, *args, **kw): - from kazoo.client import KazooClient - + def _makeOne(self, *args: Any, **kw: Any) -> KazooClient: return KazooClient(*args, **kw) - def test_invalid_handler(self): - from kazoo.handlers.threading import SequentialThreadingHandler - + def test_invalid_handler(self) -> None: with pytest.raises(ConfigurationError): self._makeOne(handler=SequentialThreadingHandler) - def test_chroot(self): + def test_chroot(self) -> None: assert self._makeOne(hosts="127.0.0.1:2181/").chroot == "" assert self._makeOne(hosts="127.0.0.1:2181/a").chroot == "/a" assert self._makeOne(hosts="127.0.0.1/a").chroot == "/a" @@ -82,35 +101,31 @@ def test_chroot(self): == "/a/b" ) - def test_connection_timeout(self): - from kazoo.handlers.threading import KazooTimeoutError - + def test_connection_timeout(self) -> None: client = self._makeOne(hosts="127.0.0.1:9") assert client.handler.timeout_exception is KazooTimeoutError with pytest.raises(KazooTimeoutError): client.start(0.1) - def test_ordered_host_selection(self): + def test_ordered_host_selection(self) -> None: client = self._makeOne( hosts="127.0.0.1:9,127.0.0.2:9/a", randomize_hosts=False ) hosts = [h for h in client.hosts] assert hosts == [("127.0.0.1", 9), ("127.0.0.2", 9)] - def test_invalid_hostname(self): + def test_invalid_hostname(self) -> None: client = self._makeOne(hosts="nosuchhost/a") timeout = client.handler.timeout_exception with pytest.raises(timeout): client.start(0.1) - def test_another_invalid_hostname(self): + def test_another_invalid_hostname(self) -> None: with pytest.raises(ValueError): self._makeOne(hosts="/nosuchhost/a") - def test_retry_options_dict(self): - from kazoo.retry import KazooRetry - + def test_retry_options_dict(self) -> None: client = self._makeOne( command_retry=dict(max_tries=99), connection_retry=dict(delay=99) ) @@ -121,12 +136,10 @@ def test_retry_options_dict(self): class TestAuthentication(KazooTestCase): - def _makeAuth(self, *args, **kwargs): - from kazoo.security import make_digest_acl - + def _makeAuth(self, *args: Any, **kwargs: Any) -> ACL: return make_digest_acl(*args, **kwargs) - def test_auth(self): + def test_auth(self) -> None: username = uuid.uuid4().hex password = uuid.uuid4().hex @@ -162,7 +175,8 @@ def test_auth(self): eve.stop() eve.close() - def test_connect_auth(self): + def test_connect_auth(self) -> None: + username = uuid.uuid4().hex password = uuid.uuid4().hex @@ -184,7 +198,7 @@ def test_connect_auth(self): client.stop() client.close() - def test_unicode_auth(self): + def test_unicode_auth(self) -> None: username = r"xe4/\hm" password = r"/\xe4hm" digest_auth = "%s:%s" % (username, password) @@ -217,17 +231,19 @@ def test_unicode_auth(self): eve.stop() eve.close() - def test_invalid_auth(self): + def test_invalid_auth(self) -> None: client = self._get_client() client.start() with pytest.raises(TypeError): - client.add_auth("digest", ("user", "pass")) + client.add_auth( + "digest", ("user", "pass") # type: ignore[arg-type] + ) with pytest.raises(TypeError): - client.add_auth(None, ("user", "pass")) + client.add_auth(None, ("user", "pass")) # type: ignore[arg-type] - def test_async_auth(self): + def test_async_auth(self) -> None: client = self._get_client() client.start() username = uuid.uuid4().hex @@ -236,7 +252,7 @@ def test_async_auth(self): result = client.add_auth_async("digest", digest_auth) assert result.get() is True - def test_async_auth_failure(self): + def test_async_auth_failure(self) -> None: client = self._get_client() client.start() username = uuid.uuid4().hex @@ -246,10 +262,11 @@ def test_async_auth_failure(self): with pytest.raises(AuthFailedError): client.add_auth("unknown-scheme", digest_auth) - def test_add_auth_on_reconnect(self): + def test_add_auth_on_reconnect(self) -> None: client = self._get_client() client.start() client.add_auth("digest", "jsmith:jsmith") + assert client._connection._socket is not None client._connection._socket.shutdown(socket.SHUT_RDWR) while not client.connected: time.sleep(0.1) @@ -258,14 +275,14 @@ def test_add_auth_on_reconnect(self): class TestConnection(KazooTestCase): @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() @staticmethod - def make_condition(): + def make_condition() -> threading.Condition: return threading.Condition() - def test_chroot_warning(self): + def test_chroot_warning(self) -> None: k = self._get_nonchroot_client() k.chroot = "abba" try: @@ -275,12 +292,11 @@ def test_chroot_warning(self): finally: k.stop() - def test_session_expire(self): - from kazoo.protocol.states import KazooState + def test_session_expire(self) -> None: cv = self.make_event() - def watch_events(event): + def watch_events(event: KazooState) -> None: if event == KazooState.LOST: cv.set() @@ -289,17 +305,15 @@ def watch_events(event): cv.wait(3) assert cv.is_set() - def test_bad_session_expire(self): - from kazoo.protocol.states import KazooState + def test_bad_session_expire(self) -> None: cv = self.make_event() ab = self.make_event() - def watch_events(event): + def watch_events(event: KazooState) -> None: if event == KazooState.LOST: ab.set() raise Exception("oops") - cv.set() self.client.add_listener(watch_events) self.expire_session(self.make_event) @@ -308,13 +322,12 @@ def watch_events(event): cv.wait(0.5) assert not cv.is_set() - def test_state_listener(self): - from kazoo.protocol.states import KazooState + def test_state_listener(self) -> None: states = [] condition = self.make_condition() - def listener(state): + def listener(state: KazooState) -> None: with condition: states.append(state) condition.notify_all() @@ -331,18 +344,17 @@ def listener(state): assert len(states) == 1 assert states[0] == KazooState.CONNECTED - def test_invalid_listener(self): + def test_invalid_listener(self) -> None: with pytest.raises(ConfigurationError): - self.client.add_listener(15) + self.client.add_listener(15) # type: ignore[arg-type] - def test_listener_only_called_on_real_state_change(self): - from kazoo.protocol.states import KazooState + def test_listener_only_called_on_real_state_change(self) -> None: assert self.client.state == KazooState.CONNECTED called = [False] condition = self.make_event() - def listener(state): + def listener(state: KazooState) -> None: called[0] = True condition.set() @@ -351,7 +363,7 @@ def listener(state): condition.wait(3) assert called[0] is False - def test_no_connection(self): + def test_no_connection(self) -> None: client = self.client client.stop() assert client.connected is False @@ -360,12 +372,12 @@ def test_no_connection(self): with pytest.raises(ConnectionClosedError): client.exists("/") - def test_close_connecting_connection(self): + def test_close_connecting_connection(self) -> None: client = self.client client.stop() ev = self.make_event() - def close_on_connecting(state): + def close_on_connecting(state: KazooState) -> None: if state in (KazooState.CONNECTED, KazooState.LOST): ev.set() @@ -385,23 +397,23 @@ def close_on_connecting(state): with pytest.raises(ConnectionClosedError): self.client.create("/foobar") - def test_double_start(self): + def test_double_start(self) -> None: assert self.client.connected is True self.client.start() assert self.client.connected is True - def test_double_stop(self): + def test_double_stop(self) -> None: self.client.stop() assert self.client.connected is False self.client.stop() assert self.client.connected is False - def test_restart(self): + def test_restart(self) -> None: assert self.client.connected is True self.client.restart() assert self.client.connected is True - def test_closed(self): + def test_closed(self) -> None: client = self.client client.stop() @@ -433,13 +445,13 @@ def test_closed(self): client._state = oldstate client._connection._write_sock = None - def test_watch_trigger_expire(self): + def test_watch_trigger_expire(self) -> None: client = self.client cv = self.make_event() client.create("/test", b"") - def test_watch(event): + def test_watch(event: WatchedEvent) -> None: cv.set() client.get("/test/", watch=test_watch) @@ -450,17 +462,11 @@ def test_watch(event): class TestClient(KazooTestCase): - def _makeOne(self, *args): - from kazoo.handlers.threading import SequentialThreadingHandler - + def _makeOne(self, *args: Any) -> SequentialThreadingHandler: return SequentialThreadingHandler(*args) - def _getKazooState(self): - from kazoo.protocol.states import KazooState + def test_server_version_retries_fail(self) -> None: - return KazooState - - def test_server_version_retries_fail(self): client = self.client side_effects = [ "", @@ -468,38 +474,38 @@ def test_server_version_retries_fail(self): "zookeeper.version=1.", "zookeeper.ver", ] - client.command = MagicMock() + client.command = MagicMock() # type: ignore[method-assign] client.command.side_effect = side_effects with pytest.raises(KazooException): client.server_version(retries=len(side_effects) - 1) - def test_server_version_retries_eventually_ok(self): + def test_server_version_retries_eventually_ok(self) -> None: client = self.client actual_version = "zookeeper.version=1.2" side_effects = [] for i in range(0, len(actual_version) + 1): side_effects.append(actual_version[0:i]) - client.command = MagicMock() + client.command = MagicMock() # type: ignore[method-assign] client.command.side_effect = side_effects assert client.server_version(retries=len(side_effects) - 1) == (1, 2) - def test_client_id(self): + def test_client_id(self) -> None: client_id = self.client.client_id assert type(client_id) is tuple # make sure password is of correct length assert len(client_id[1]) == 16 - def test_connected(self): + def test_connected(self) -> None: client = self.client assert client.connected - def test_create(self): + def test_create(self) -> None: client = self.client path = client.create("/1") assert path == "/1" assert client.exists("/1") - def test_create_on_broken_connection(self): + def test_create_on_broken_connection(self) -> None: client = self.client client.start() @@ -517,34 +523,34 @@ def test_create_on_broken_connection(self): with pytest.raises(ConnectionClosedError): client.create("/closedpath", b"bar") - def test_create_null_data(self): + def test_create_null_data(self) -> None: client = self.client client.create("/nulldata", None) value, _ = client.get("/nulldata") assert value is None - def test_create_empty_string(self): + def test_create_empty_string(self) -> None: client = self.client client.create("/empty", b"") value, _ = client.get("/empty") assert value == b"" - def test_create_unicode_path(self): + def test_create_unicode_path(self) -> None: client = self.client path = client.create("/ascii") assert path == "/ascii" path = client.create("/\xe4hm") assert path == "/\xe4hm" - def test_create_async_returns_unchrooted_path(self): + def test_create_async_returns_unchrooted_path(self) -> None: client = self.client path = client.create_async("/1").get() assert path == "/1" - def test_create_invalid_path(self): + def test_create_invalid_path(self) -> None: client = self.client with pytest.raises(TypeError): - client.create(("a",)) + client.create(("a",)) # type:ignore[call-overload] with pytest.raises(ValueError): client.create(".") with pytest.raises(ValueError): @@ -554,36 +560,34 @@ def test_create_invalid_path(self): with pytest.raises(BadArgumentsError): client.create("/b\x1e") - def test_create_invalid_arguments(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_create_invalid_arguments(self) -> None: single_acl = OPEN_ACL_UNSAFE[0] client = self.client with pytest.raises(TypeError): - client.create("a", acl="all") + client.create("a", acl="all") # type: ignore[arg-type] with pytest.raises(TypeError): - client.create("a", acl=single_acl) + client.create("a", acl=single_acl) # type: ignore[arg-type] with pytest.raises(TypeError): - client.create("a", value=["a"]) + client.create("a", value=["a"]) # type: ignore[call-overload] with pytest.raises(TypeError): - client.create("a", ephemeral="yes") + client.create("a", ephemeral="yes") # type: ignore[call-overload] with pytest.raises(TypeError): - client.create("a", sequence="yes") + client.create("a", sequence="yes") # type: ignore[call-overload] with pytest.raises(TypeError): - client.create("a", makepath="yes") + client.create("a", makepath="yes") # type: ignore[call-overload] - def test_create_value(self): + def test_create_value(self) -> None: client = self.client client.create("/1", b"bytes") data, stat = client.get("/1") assert data == b"bytes" - def test_create_unicode_value(self): + def test_create_unicode_value(self) -> None: client = self.client with pytest.raises(TypeError): - client.create("/1", "\xe4hm") + client.create("/1", "\xe4hm") # type: ignore[call-overload] - def test_create_large_value(self): + def test_create_large_value(self) -> None: client = self.client kb_512 = b"a" * (512 * 1024) client.create("/1", kb_512) @@ -592,49 +596,41 @@ def test_create_large_value(self): with pytest.raises(ConnectionLoss): client.create("/2", mb_2) - def test_create_acl_duplicate(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_create_acl_duplicate(self) -> None: single_acl = OPEN_ACL_UNSAFE[0] client = self.client client.create("/1", acl=[single_acl, single_acl]) acls, stat = client.get_acls("/1") # ZK >3.4 removes duplicate ACL entries - if CI_ZK_VERSION: - version = CI_ZK_VERSION - else: - version = client.server_version() + version = CI_ZK_VERSION if CI_ZK_VERSION else client.server_version() assert len(acls) == 1 if version > (3, 4) else 2 - def test_create_acl_empty_list(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_create_acl_empty_list(self) -> None: client = self.client client.create("/1", acl=[]) acls, stat = client.get_acls("/1") assert acls == OPEN_ACL_UNSAFE - def test_version_no_connection(self): + def test_version_no_connection(self) -> None: self.client.stop() with pytest.raises(ConnectionLoss): self.client.server_version() - def test_create_ephemeral(self): + def test_create_ephemeral(self) -> None: client = self.client client.create("/1", b"ephemeral", ephemeral=True) + assert client.client_id is not None data, stat = client.get("/1") assert data == b"ephemeral" assert stat.ephemeralOwner == client.client_id[0] - def test_create_no_ephemeral(self): + def test_create_no_ephemeral(self) -> None: client = self.client client.create("/1", b"val1") data, stat = client.get("/1") assert not stat.ephemeralOwner - def test_create_ephemeral_no_children(self): - from kazoo.exceptions import NoChildrenForEphemeralsError - + def test_create_ephemeral_no_children(self) -> None: client = self.client client.create("/1", b"ephemeral", ephemeral=True) with pytest.raises(NoChildrenForEphemeralsError): @@ -642,7 +638,7 @@ def test_create_ephemeral_no_children(self): with pytest.raises(NoChildrenForEphemeralsError): client.create("/1/2", b"val1", ephemeral=True) - def test_create_sequence(self): + def test_create_sequence(self) -> None: client = self.client client.create("/folder") path = client.create("/folder/a", b"sequence", sequence=True) @@ -652,7 +648,7 @@ def test_create_sequence(self): path3 = client.create("/folder/", b"sequence", sequence=True) assert path3 == "/folder/0000000002" - def test_create_ephemeral_sequence(self): + def test_create_ephemeral_sequence(self) -> None: basepath = "/" + uuid.uuid4().hex realpath = self.client.create( basepath, b"sandwich", sequence=True, ephemeral=True @@ -661,7 +657,7 @@ def test_create_ephemeral_sequence(self): data, stat = self.client.get(realpath) assert data == b"sandwich" - def test_create_makepath(self): + def test_create_makepath(self) -> None: self.client.create("/1/2", b"val1", makepath=True) data, stat = self.client.get("/1/2") assert data == b"val1" @@ -673,10 +669,7 @@ def test_create_makepath(self): with pytest.raises(NodeExistsError): self.client.create("/1/2/3/4/5", b"val2", makepath=True) - def test_create_makepath_incompatible_acls(self): - from kazoo.client import KazooClient - from kazoo.security import make_digest_acl_credential, CREATOR_ALL_ACL - + def test_create_makepath_incompatible_acls(self) -> None: credential = make_digest_acl_credential("username", "password") alt_client = KazooClient( self.cluster[0].address + self.client.chroot, @@ -695,7 +688,7 @@ def test_create_makepath_incompatible_acls(self): alt_client.delete("/", recursive=True) alt_client.stop() - def test_create_no_makepath(self): + def test_create_no_makepath(self) -> None: with pytest.raises(NoNodeError): self.client.create("/1/2", b"val1") with pytest.raises(NoNodeError): @@ -705,15 +698,13 @@ def test_create_no_makepath(self): with pytest.raises(NoNodeError): self.client.create("/1/2/3/4", b"val1", makepath=False) - def test_create_exists(self): - from kazoo.exceptions import NodeExistsError - + def test_create_exists(self) -> None: client = self.client path = client.create("/1") with pytest.raises(NodeExistsError): client.create(path) - def test_create_stat(self): + def test_create_stat(self) -> None: if CI_ZK_VERSION: version = CI_ZK_VERSION else: @@ -726,7 +717,7 @@ def test_create_stat(self): assert data == b"bytes" assert stat1 == stat2 - def test_create_get_set(self): + def test_create_get_set(self) -> None: nodepath = "/" + uuid.uuid4().hex self.client.create(nodepath, b"sandwich", ephemeral=True) @@ -749,20 +740,20 @@ def test_create_get_set(self): assert newstat.children_count == stat.numChildren assert newstat.children_version == stat.cversion - def test_get_invalid_arguments(self): + def test_get_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.get(("a", "b")) + client.get(("a", "b")) # type: ignore[arg-type] with pytest.raises(TypeError): - client.get("a", watch=True) + client.get("a", watch=True) # type: ignore[arg-type] - def test_bad_argument(self): + def test_bad_argument(self) -> None: client = self.client client.ensure_path("/1") with pytest.raises(TypeError): - self.client.set("/1", 1) + self.client.set("/1", 1) # type: ignore[arg-type] - def test_ensure_path(self): + def test_ensure_path(self) -> None: client = self.client client.ensure_path("/1/2") assert client.exists("/1/2") @@ -770,13 +761,13 @@ def test_ensure_path(self): client.ensure_path("/1/2/3/4") assert client.exists("/1/2/3/4") - def test_sync(self): + def test_sync(self) -> None: client = self.client assert client.sync("/") == "/" # Albeit surprising, you can sync anything, even what does not exist. assert client.sync("/not_there") == "/not_there" - def test_exists(self): + def test_exists(self) -> None: nodepath = "/" + uuid.uuid4().hex exists = self.client.exists(nodepath) @@ -791,18 +782,18 @@ def test_exists(self): exists = self.client.exists(multi_node_nonexistent) assert exists is None - def test_exists_invalid_arguments(self): + def test_exists_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.exists(("a", "b")) + client.exists(("a", "b")) # type: ignore[arg-type] with pytest.raises(TypeError): - client.exists("a", watch=True) + client.exists("a", watch=True) # type: ignore[arg-type] - def test_exists_watch(self): + def test_exists_watch(self) -> None: nodepath = "/" + uuid.uuid4().hex event = self.client.handler.event_object() - def w(watch_event): + def w(watch_event: WatchedEvent) -> None: assert watch_event.path == nodepath event.set() @@ -814,12 +805,12 @@ def w(watch_event): event.wait(1) assert event.is_set() is True - def test_exists_watcher_exception(self): + def test_exists_watcher_exception(self) -> None: nodepath = "/" + uuid.uuid4().hex event = self.client.handler.event_object() # if the watcher throws an exception, all we can really do is log it - def w(watch_event): + def w(watch_event: WatchedEvent) -> None: assert watch_event.path == nodepath event.set() @@ -833,7 +824,7 @@ def w(watch_event): event.wait(1) assert event.is_set() is True - def test_create_delete(self): + def test_create_delete(self) -> None: nodepath = "/" + uuid.uuid4().hex self.client.create(nodepath, b"zzz") @@ -843,9 +834,7 @@ def test_create_delete(self): exists = self.client.exists(nodepath) assert exists is None - def test_get_acls(self): - from kazoo.security import make_digest_acl - + def test_get_acls(self) -> None: user = "user" passw = "pass" acl = make_digest_acl(user, passw, all=True) @@ -857,14 +846,12 @@ def test_get_acls(self): finally: client.delete("/a") - def test_get_acls_invalid_arguments(self): + def test_get_acls_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.get_acls(("a", "b")) - - def test_set_acls(self): - from kazoo.security import make_digest_acl + client.get_acls(("a", "b")) # type: ignore[arg-type] + def test_set_acls(self) -> None: user = "user" passw = "pass" acl = make_digest_acl(user, passw, all=True) @@ -877,34 +864,30 @@ def test_set_acls(self): finally: client.delete("/a") - def test_set_acls_empty(self): + def test_set_acls_empty(self) -> None: client = self.client client.create("/a") with pytest.raises(InvalidACLError): client.set_acls("/a", []) - def test_set_acls_no_node(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_set_acls_no_node(self) -> None: client = self.client with pytest.raises(NoNodeError): client.set_acls("/a", OPEN_ACL_UNSAFE) - def test_set_acls_invalid_arguments(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_set_acls_invalid_arguments(self) -> None: single_acl = OPEN_ACL_UNSAFE[0] client = self.client with pytest.raises(TypeError): - client.set_acls(("a", "b"), ()) + client.set_acls(("a", "b"), ()) # type: ignore[arg-type] with pytest.raises(TypeError): - client.set_acls("a", single_acl) + client.set_acls("a", single_acl) # type: ignore[arg-type] with pytest.raises(TypeError): - client.set_acls("a", "all") + client.set_acls("a", "all") # type: ignore[arg-type] with pytest.raises(TypeError): - client.set_acls("a", [single_acl], "V1") + client.set_acls("a", [single_acl], "V1") # type: ignore[arg-type] - def test_set(self): + def test_set(self) -> None: client = self.client client.create("a", b"first") stat = client.set("a", b"second") @@ -912,38 +895,38 @@ def test_set(self): assert data == b"second" assert stat == stat2 - def test_set_null_data(self): + def test_set_null_data(self) -> None: client = self.client client.create("/nulldata", b"not none") client.set("/nulldata", None) value, _ = client.get("/nulldata") assert value is None - def test_set_empty_string(self): + def test_set_empty_string(self) -> None: client = self.client client.create("/empty", b"not empty") client.set("/empty", b"") value, _ = client.get("/empty") assert value == b"" - def test_set_invalid_arguments(self): + def test_set_invalid_arguments(self) -> None: client = self.client client.create("a", b"first") with pytest.raises(TypeError): - client.set(("a", "b"), b"value") + client.set(("a", "b"), b"value") # type: ignore[arg-type] with pytest.raises(TypeError): - client.set("a", ["v", "w"]) + client.set("a", ["v", "w"]) # type: ignore[arg-type] with pytest.raises(TypeError): - client.set("a", b"value", "V1") + client.set("a", b"value", "V1") # type: ignore[arg-type] - def test_delete(self): + def test_delete(self) -> None: client = self.client client.ensure_path("/a/b") assert "b" in client.get_children("a") client.delete("/a/b") assert "b" not in client.get_children("a") - def test_delete_recursive(self): + def test_delete_recursive(self) -> None: client = self.client client.ensure_path("/a/b/c") client.ensure_path("/a/b/d") @@ -951,17 +934,17 @@ def test_delete_recursive(self): client.delete("/a/b/c", recursive=True) assert "b" not in client.get_children("a") - def test_delete_invalid_arguments(self): + def test_delete_invalid_arguments(self) -> None: client = self.client client.ensure_path("/a/b") with pytest.raises(TypeError): - client.delete("/a/b", recursive="all") + client.delete("/a/b", recursive="all") # type: ignore[arg-type] with pytest.raises(TypeError): - client.delete(("a", "b")) + client.delete(("a", "b")) # type: ignore[arg-type] with pytest.raises(TypeError): - client.delete("/a/b", version="V1") + client.delete("/a/b", version="V1") # type: ignore[arg-type] - def test_get_children(self): + def test_get_children(self) -> None: client = self.client client.ensure_path("/a/b/c") client.ensure_path("/a/b/d") @@ -969,7 +952,7 @@ def test_get_children(self): assert set(client.get_children("/a/b")) == set(["c", "d"]) assert client.get_children("/a/b/c") == [] - def test_get_children2(self): + def test_get_children2(self) -> None: client = self.client client.ensure_path("/a/b") children, stat = client.get_children("/a", include_data=True) @@ -977,7 +960,7 @@ def test_get_children2(self): assert children == ["b"] assert stat2.version == stat.version - def test_get_children2_many_nodes(self): + def test_get_children2_many_nodes(self) -> None: client = self.client client.ensure_path("/a/b") client.ensure_path("/a/c") @@ -987,31 +970,30 @@ def test_get_children2_many_nodes(self): assert set(children) == set(["b", "c", "d"]) assert stat2.version == stat.version - def test_get_children_no_node(self): + def test_get_children_no_node(self) -> None: client = self.client with pytest.raises(NoNodeError): client.get_children("/none") with pytest.raises(NoNodeError): client.get_children("/none", include_data=True) - def test_get_children_invalid_path(self): + def test_get_children_invalid_path(self) -> None: client = self.client with pytest.raises(ValueError): client.get_children("../a") - def test_get_children_invalid_arguments(self): + def test_get_children_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.get_children(("a", "b")) + client.get_children(("a", "b")) # type: ignore[call-overload] with pytest.raises(TypeError): - client.get_children("a", watch=True) + client.get_children("a", watch=True) # type: ignore[call-overload] with pytest.raises(TypeError): - client.get_children("a", include_data="yes") - - def test_invalid_auth(self): - from kazoo.exceptions import AuthFailedError - from kazoo.protocol.states import KeeperState + client.get_children( # type: ignore[call-overload] + "a", include_data="yes" + ) + def test_invalid_auth(self) -> None: client = self.client client.stop() client._state = KeeperState.AUTH_FAILED @@ -1019,15 +1001,10 @@ def test_invalid_auth(self): with pytest.raises(AuthFailedError): client.get("/") - def test_client_state(self): - from kazoo.protocol.states import KeeperState - + def test_client_state(self) -> None: assert self.client.client_state == KeeperState.CONNECTED - def test_update_host_list(self): - from kazoo.client import KazooClient - from kazoo.protocol.states import KeeperState - + def test_update_host_list(self) -> None: hosts = self.cluster[0].address # create a client with only one server in its list client = KazooClient(hosts=hosts) @@ -1049,8 +1026,9 @@ def test_update_host_list(self): self.cluster[0].run() # utility for test_request_queuing* - def _make_request_queuing_client(self): - from kazoo.client import KazooClient + def _make_request_queuing_client( + self, + ) -> tuple[KazooClient, ManagedZooKeeper]: server = self.cluster[0] handler = self._makeOne() @@ -1071,11 +1049,17 @@ def _make_request_queuing_client(self): return client, server # utility for test_request_queuing* - def _request_queuing_common(self, client, server, path, expire_session): + def _request_queuing_common( + self, + client: KazooClient, + server: ManagedZooKeeper, + path: str, + expire_session: bool, + ) -> IAsyncResult: ev_suspended = client.handler.event_object() ev_connected = client.handler.event_object() - def listener(state): + def listener(state: KazooState) -> None: if state == KazooState.SUSPENDED: ev_suspended.set() elif state == KazooState.CONNECTED: @@ -1117,7 +1101,7 @@ def listener(state): return result - def test_request_queuing_session_recovered(self): + def test_request_queuing_session_recovered(self) -> None: path = "/" + uuid.uuid4().hex client, server = self._make_request_queuing_client() @@ -1131,7 +1115,7 @@ def test_request_queuing_session_recovered(self): finally: client.stop() - def test_request_queuing_session_expired(self): + def test_request_queuing_session_expired(self) -> None: path = "/" + uuid.uuid4().hex client, server = self._make_request_queuing_client() @@ -1148,7 +1132,7 @@ def test_request_queuing_session_expired(self): class TestSSLClient(KazooTestCase): - def setUp(self): + def setUp(self) -> None: if CI_ZK_VERSION and CI_ZK_VERSION < (3, 5): pytest.skip("Must use Zookeeper 3.5 or above") ssl_path = tempfile.mkdtemp() @@ -1171,7 +1155,7 @@ def setUp(self): use_ssl=True, keyfile=key_path, certfile=cert_path, ca=cacert_path ) - def test_create(self): + def test_create(self) -> None: client = self.client path = client.create("/1") assert path == "/1" @@ -1194,7 +1178,7 @@ def test_create(self): class TestClientTransactions(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) skip = False if CI_ZK_VERSION and CI_ZK_VERSION < (3, 4): @@ -1208,7 +1192,7 @@ def setUp(self): if skip: pytest.skip("Must use Zookeeper 3.4 or above") - def test_basic_create(self): + def test_basic_create(self) -> None: t = self.client.transaction() t.create("/freddy") t.create("/fred", ephemeral=True) @@ -1218,7 +1202,7 @@ def test_basic_create(self): assert results[0] == "/freddy" assert results[2].startswith("/smith0") is True - def test_bad_creates(self): + def test_bad_creates(self) -> None: args_list = [ (True,), ("/smith", 0), @@ -1230,11 +1214,9 @@ def test_bad_creates(self): for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.create(*args) - - def test_default_acl(self): - from kazoo.security import make_digest_acl + t.create(*args) # type: ignore[arg-type] + def test_default_acl(self) -> None: username = uuid.uuid4().hex password = uuid.uuid4().hex @@ -1249,14 +1231,14 @@ def test_default_acl(self): results = t.commit() assert results[0] == "/freddy" - def test_basic_delete(self): + def test_basic_delete(self) -> None: self.client.create("/fred") t = self.client.transaction() t.delete("/fred") results = t.commit() assert results[0] is True - def test_bad_deletes(self): + def test_bad_deletes(self) -> None: args_list = [ (True,), ("/smith", "woops"), @@ -1265,9 +1247,9 @@ def test_bad_deletes(self): for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.delete(*args) + t.delete(*args) # type: ignore[arg-type] - def test_set(self): + def test_set(self) -> None: self.client.create("/fred", b"01") t = self.client.transaction() t.set_data("/fred", b"oops") @@ -1275,15 +1257,15 @@ def test_set(self): res = self.client.get("/fred") assert res[0] == b"oops" - def test_bad_sets(self): + def test_bad_sets(self) -> None: args_list = [(42, 52), ("/smith", False), ("/smith", b"", "oops")] for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.set_data(*args) + t.set_data(*args) # type: ignore[arg-type] - def test_check(self): + def test_check(self) -> None: self.client.create("/fred") version = self.client.get("/fred")[1].version t = self.client.transaction() @@ -1293,17 +1275,15 @@ def test_check(self): assert results[0] is True assert results[1] == "/blah" - def test_bad_checks(self): + def test_bad_checks(self) -> None: args_list = [(42, 52), ("/smith", "oops")] for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.check(*args) - - def test_bad_transaction(self): - from kazoo.exceptions import RolledBackError, NoNodeError + t.check(*args) # type: ignore[arg-type] + def test_bad_transaction(self) -> None: t = self.client.transaction() t.create("/fred") t.delete("/smith") @@ -1311,57 +1291,56 @@ def test_bad_transaction(self): assert results[0].__class__ == RolledBackError assert results[1].__class__ == NoNodeError - def test_bad_commit(self): + def test_bad_commit(self) -> None: t = self.client.transaction() t.committed = True with pytest.raises(ValueError): t.commit() - def test_bad_context(self): + def test_bad_context(self) -> None: with pytest.raises(TypeError): with self.client.transaction() as t: - t.check(4232) + t.check(4232) # type: ignore[arg-type,call-arg] - def test_context(self): + def test_context(self) -> None: with self.client.transaction() as t: t.create("/smith", b"32") assert self.client.get("/smith")[0] == b"32" class TestSessionCallbacks(unittest.TestCase): - def test_session_callback_states(self): - from kazoo.protocol.states import KazooState, KeeperState - from kazoo.client import KazooClient - + def test_session_callback_states(self) -> None: client = KazooClient() - client._handle = 1 client._live.set() - result = client._session_callback(KeeperState.CONNECTED) - assert result is None + client._session_callback(KeeperState.CONNECTED) # Now with stopped client._stopped.set() - result = client._session_callback(KeeperState.CONNECTED) - assert result is None + client._session_callback(KeeperState.CONNECTED) # Test several state transitions client._stopped.clear() - client.start_async = lambda: True + client.start_async = ( # type: ignore[method-assign] + lambda: threading.Event() + ) client._session_callback(KeeperState.CONNECTED) assert client.state == KazooState.CONNECTED client._session_callback(KeeperState.AUTH_FAILED) - assert client.state == KazooState.LOST + # FIXME mypy seems to be under the impression that the state can't + # change as a result of the above call, even though it can. + assert ( + client.state == KazooState.LOST # type: ignore[comparison-overlap] + ) - client._handle = 1 - client._session_callback(-250) + client._session_callback(-250) # type: ignore[unreachable] assert client.state == KazooState.SUSPENDED class TestCallbacks(KazooTestCase): - def test_async_result_callbacks_are_always_called(self): + def test_async_result_callbacks_are_always_called(self) -> None: # create a callback object callback_mock = Mock() @@ -1384,7 +1363,7 @@ def test_async_result_callbacks_are_always_called(self): class TestNonChrootClient(KazooTestCase): - def test_create(self): + def test_create(self) -> None: client = self._get_nonchroot_client() assert client.chroot == "" client.start() @@ -1393,7 +1372,7 @@ def test_create(self): client.delete(path) client.stop() - def test_unchroot(self): + def test_unchroot(self) -> None: client = self._get_nonchroot_client() client.chroot = "/a" # Unchroot'ing the chroot path should return "/" @@ -1403,7 +1382,7 @@ def test_unchroot(self): class TestReconfig(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) if CI_ZK_VERSION: @@ -1413,7 +1392,7 @@ def setUp(self): if not version or version < (3, 5): pytest.skip("Must use Zookeeper 3.5 or above") - def test_no_super_auth(self): + def test_no_super_auth(self) -> None: with pytest.raises(NoAuthError): self.client.reconfig( joining="server.999=0.0.0.0:1234:2345:observer;3456", @@ -1421,8 +1400,8 @@ def test_no_super_auth(self): new_members=None, ) - def test_add_remove_observer(self): - def free_sock_port(): + def test_add_remove_observer(self) -> None: + def free_sock_port() -> tuple[socket.socket, int]: s = socket.socket() s.bind(("", 0)) return s, s.getsockname()[1] @@ -1466,7 +1445,7 @@ def free_sock_port(): from_config=curver + 1, ) - def test_bad_input(self): + def test_bad_input(self) -> None: with pytest.raises(BadArgumentsError): self.client.reconfig( joining="some thing", leaving=None, new_members=None diff --git a/kazoo/tests/test_connection.py b/kazoo/tests/test_connection.py index 032b94bb..0579a91b 100644 --- a/kazoo/tests/test_connection.py +++ b/kazoo/tests/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import namedtuple, deque import os import threading @@ -7,15 +9,18 @@ import struct import sys +from typing import Any, Iterable, Deque, Tuple import pytest -from kazoo.exceptions import ConnectionLoss +from kazoo.client import KazooClient +from kazoo.exceptions import ConnectionLoss, NotReadOnlyCallError +from kazoo.interfaces import FdLike from kazoo.protocol.serialization import ( Connect, int_struct, write_string, ) -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, KeeperState from kazoo.protocol.connection import _CONNECTION_DROP from kazoo.testing import KazooTestCase from kazoo.tests.util import wait, CI_ZK_VERSION, CI @@ -24,32 +29,33 @@ class Delete(namedtuple("Delete", "path version")): type = 2 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b @classmethod - def deserialize(self, bytes, offset): + def deserialize(self, bytes: bytes, offset: int) -> None: raise ValueError("oh my") class TestConnectionHandler(KazooTestCase): - def test_bad_deserialization(self): + def test_bad_deserialization(self) -> None: async_object = self.client.handler.async_result() self.client._queue.append( (Delete(self.client.chroot, -1), async_object) ) + assert self.client._connection._write_sock is not None self.client._connection._write_sock.send(b"\0") with pytest.raises(ValueError): async_object.get() - def test_with_bad_sessionid(self): + def test_with_bad_sessionid(self) -> None: ev = threading.Event() - def expired(state): + def expired(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() @@ -63,7 +69,7 @@ def expired(state): finally: client.stop() - def test_connection_read_timeout(self): + def test_connection_read_timeout(self) -> None: client = self.client ev = threading.Event() path = "/" + uuid.uuid4().hex @@ -71,31 +77,33 @@ def test_connection_read_timeout(self): _select = handler.select _socket = client._connection._socket - def delayed_select(*args, **kwargs): + def delayed_select( + *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: result = _select(*args, **kwargs) if len(args[0]) == 1 and _socket in args[0]: # for any socket read, simulate a timeout return [], [], [] return result - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() client.add_listener(back) client.create(path, b"1") try: - handler.select = delayed_select + handler.select = delayed_select # type: ignore[method-assign] with pytest.raises(ConnectionLoss): client.get(path) finally: - handler.select = _select + handler.select = _select # type: ignore[method-assign] # the client reconnects automatically ev.wait(5) assert ev.is_set() assert client.get(path)[0] == b"1" - def test_connection_write_timeout(self): + def test_connection_write_timeout(self) -> None: client = self.client ev = threading.Event() path = "/" + uuid.uuid4().hex @@ -103,31 +111,33 @@ def test_connection_write_timeout(self): _select = handler.select _socket = client._connection._socket - def delayed_select(*args, **kwargs): + def delayed_select( + *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: result = _select(*args, **kwargs) if _socket in args[1]: # for any socket write, simulate a timeout return [], [], [] return result - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() client.add_listener(back) try: - handler.select = delayed_select + handler.select = delayed_select # type: ignore[method-assign] with pytest.raises(ConnectionLoss): client.create(path) finally: - handler.select = _select + handler.select = _select # type: ignore[method-assign] # the client reconnects automatically ev.wait(5) assert ev.is_set() assert client.exists(path) is None - def test_connection_deserialize_fail(self): + def test_connection_deserialize_fail(self) -> None: client = self.client ev = threading.Event() path = "/" + uuid.uuid4().hex @@ -135,14 +145,16 @@ def test_connection_deserialize_fail(self): _select = handler.select _socket = client._connection._socket - def delayed_select(*args, **kwargs): + def delayed_select( + *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: result = _select(*args, **kwargs) if _socket in args[1]: # for any socket write, simulate a timeout return [], [], [] return result - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() @@ -150,7 +162,7 @@ def back(state): deserialize_ev = threading.Event() - def bad_deserialize(_bytes, offset): + def bad_deserialize(_bytes: bytes, offset: int) -> None: deserialize_ev.set() raise struct.error() @@ -162,11 +174,11 @@ def bad_deserialize(_bytes, offset): with patch.object(Connect, "deserialize") as mock_deserialize: mock_deserialize.side_effect = bad_deserialize try: - handler.select = delayed_select + handler.select = delayed_select # type: ignore[method-assign] with pytest.raises(ConnectionLoss): client.create(path) finally: - handler.select = _select + handler.select = _select # type: ignore[method-assign] # the client reconnects automatically but the first attempt will # hit a deserialize failure. wait for that. deserialize_ev.wait(5) @@ -177,7 +189,7 @@ def bad_deserialize(_bytes, offset): assert ev.is_set() assert client.exists(path) is None - def test_connection_close(self): + def test_connection_close(self) -> None: with pytest.raises(Exception): self.client.close() self.client.stop() @@ -186,7 +198,7 @@ def test_connection_close(self): # should be able to restart self.client.start() - def test_connection_sock(self): + def test_connection_sock(self) -> None: client = self.client read_sock = client._connection._read_sock write_sock = client._connection._write_sock @@ -217,10 +229,12 @@ def test_connection_sock(self): read_sock.getsockname() write_sock.getsockname() - def test_dirty_sock(self): + def test_dirty_sock(self) -> None: client = self.client read_sock = client._connection._read_sock write_sock = client._connection._write_sock + assert read_sock is not None + assert write_sock is not None # add a stray byte to the socket and ensure that doesn't # blow up client. simulates case where some error leaves @@ -233,10 +247,10 @@ def test_dirty_sock(self): class TestConnectionDrop(KazooTestCase): - def test_connection_dropped(self): + def test_connection_dropped(self) -> None: ev = threading.Event() - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() @@ -245,7 +259,7 @@ def back(state): self.client.create(path) self.client.add_listener(back) result = self.client.set_async(path, b"a" * 1000 * 1024) - self.client._call(_CONNECTION_DROP, None) + self.client._call(_CONNECTION_DROP, None) # type: ignore[arg-type] with pytest.raises(ConnectionLoss): result.get() @@ -255,7 +269,7 @@ def back(state): class TestReadOnlyMode(KazooTestCase): - def setUp(self): + def setUp(self) -> None: os.environ["ZOOKEEPER_LOCAL_SESSION_RO"] = "true" self.setup_zookeeper() skip = False @@ -270,31 +284,28 @@ def setUp(self): if skip: pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.client.stop() os.environ.pop("ZOOKEEPER_LOCAL_SESSION_RO", None) - def test_read_only(self): - from kazoo.exceptions import NotReadOnlyCallError - from kazoo.protocol.states import KeeperState - + def test_read_only(self) -> None: if CI: # force some wait to make sure the data produced during the - # `setUp()` step are replicaed to all zk members + # `setUp()` step are replicated to all zk members # if not done the `get_children()` test may fail because the # node does not exist on the node that we will keep alive time.sleep(15) # do not keep the client started in the `setUp` step alive self.client.stop() client = self._get_client(connection_retry=None, read_only=True) - states = [] ev = threading.Event() - @client.add_listener - def listen(state): - states.append(state) + def listen(state: KazooState) -> bool | None: if client.client_state == KeeperState.CONNECTED_RO: ev.set() + return None + + client.add_listener(listen) client.start() try: @@ -336,7 +347,7 @@ def listen(state): class TestUnorderedXids(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(TestUnorderedXids, self).setUp() self.connection = self.client._connection @@ -345,46 +356,53 @@ def setUp(self): self._pending = self.client._pending self.client._pending = _naughty_deque() - def tearDown(self): + def tearDown(self) -> None: self.client._pending = self._pending super(TestUnorderedXids, self).tearDown() - def _get_client(self, **kwargs): + def _get_client(self, **kwargs: Any) -> KazooClient: # overrides for patching zk_loop c = KazooTestCase._get_client(self, **kwargs) self._zk_loop = c._connection.zk_loop - self._zk_loop_errors = [] - c._connection.zk_loop = self._zk_loop_func + self._zk_loop_errors: list[BaseException] = [] + c._connection.zk_loop = ( # type: ignore[method-assign] + self._zk_loop_func + ) return c - def _zk_loop_func(self, *args, **kwargs): + def _zk_loop_func(self, *args: Any, **kwargs: Any) -> None: # patched zk_loop which will catch and collect all RuntimeError try: self._zk_loop(*args, **kwargs) except RuntimeError as e: self._zk_loop_errors.append(e) - def test_xids_mismatch(self): + def test_xids_mismatch(self) -> None: from kazoo.protocol.states import KeeperState ev = threading.Event() error_stack = [] - @self.client.add_listener - def listen(state): + def listen(state: KazooState) -> bool | None: if self.client.client_state == KeeperState.CLOSED: ev.set() + return None - def log_exception(*args): + self.client.add_listener(listen) + + def log_exception(*args: Any) -> None: error_stack.append((args, sys.exc_info())) - self.connection.logger.exception = log_exception + self.connection.logger.exception = ( # type: ignore[method-assign] + log_exception # type: ignore[assignment] + ) ev.clear() with pytest.raises(RuntimeError): self.client.get_children("/") ev.wait() + self.client.remove_listener(listen) assert self.client.connected is False assert self.client.state == "LOST" assert self.client.client_state == KeeperState.CLOSED @@ -394,12 +412,13 @@ def log_exception(*args): assert exc_info[0] == RuntimeError self.client.handler.sleep_func(0.2) + assert self.connection_routine is not None assert not self.connection_routine.is_alive() assert len(self._zk_loop_errors) == 1 assert self._zk_loop_errors[0] == exc_info[1] -class _naughty_deque(deque): - def append(self, s): +class _naughty_deque(Deque[Tuple[Any, Any, int]]): + def append(self, s: Tuple[Any, Any, int]) -> None: request, async_object, xid = s - return deque.append(self, (request, async_object, xid + 1)) # +1s + deque.append(self, (request, async_object, xid + 1)) # +1s diff --git a/kazoo/tests/test_counter.py b/kazoo/tests/test_counter.py index a7866735..aff4012b 100644 --- a/kazoo/tests/test_counter.py +++ b/kazoo/tests/test_counter.py @@ -1,16 +1,21 @@ +from __future__ import annotations + import uuid +from typing import Any + import pytest +from kazoo.recipe.counter import Counter from kazoo.testing import KazooTestCase class KazooCounterTests(KazooTestCase): - def _makeOne(self, **kw): + def _makeOne(self, **kw: Any) -> Counter: path = "/" + uuid.uuid4().hex return self.client.Counter(path, **kw) - def test_int_counter(self): + def test_int_counter(self) -> None: counter = self._makeOne() assert counter.value == 0 counter += 2 @@ -20,7 +25,7 @@ def test_int_counter(self): counter - 1 assert counter.value == -1 - def test_int_curator_counter(self): + def test_int_curator_counter(self) -> None: counter = self._makeOne(support_curator=True) assert counter.value == 0 counter += 2 @@ -36,7 +41,7 @@ def test_int_curator_counter(self): counter -= 2147483647 assert counter.value == -2147483647 - def test_float_counter(self): + def test_float_counter(self) -> None: counter = self._makeOne(default=0.0) assert counter.value == 0.0 counter += 2.1 @@ -44,16 +49,18 @@ def test_float_counter(self): counter -= 3.1 assert counter.value == -1.0 - def test_errors(self): + def test_errors(self) -> None: counter = self._makeOne() with pytest.raises(TypeError): - counter.__add__(2.1) + counter.__add__(2.1) # type: ignore[arg-type] with pytest.raises(TypeError): - counter.__add__(b"a") + counter.__add__(b"a") # type: ignore[operator] with pytest.raises(TypeError): - counter = self._makeOne(default=0.0, support_curator=True) + counter = self._makeOne( # type: ignore[arg-type] + default=0.0, support_curator=True + ) - def test_pre_post_values(self): + def test_pre_post_values(self) -> None: counter = self._makeOne() assert counter.value == 0 assert counter.pre_value is None diff --git a/kazoo/tests/test_election.py b/kazoo/tests/test_election.py index c7b1b595..d2cb1e46 100644 --- a/kazoo/tests/test_election.py +++ b/kazoo/tests/test_election.py @@ -1,19 +1,27 @@ +from __future__ import annotations + import uuid import sys import threading +from typing import TYPE_CHECKING, cast import pytest +from kazoo.recipe.election import Election from kazoo.testing import KazooTestCase from kazoo.tests.util import wait +if TYPE_CHECKING: + from types import TracebackType + from _typeshed import OptExcInfo + class UniqueError(Exception): """Error raised only by test leader function""" class KazooElectionTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooElectionTests, self).setUp() self.path = "/" + uuid.uuid4().hex @@ -21,17 +29,19 @@ def setUp(self): # election contenders set these when elected. The exit event is set by # the test to make the leader exit. - self.leader_id = None - self.exit_event = None + self.leader_id: str | None = None + self.exit_event: threading.Event | None = None # tests set this before the event to make the leader raise an error self.raise_exception = False # set by a worker thread when an unexpected error is hit. # better way to do this? - self.thread_exc_info = None + self.thread_exc_info: OptExcInfo | None = None - def _spawn_contender(self, contender_id, election): + def _spawn_contender( + self, contender_id: str, election: Election + ) -> threading.Thread: thread = threading.Thread( target=self._election_thread, args=(contender_id, election) ) @@ -39,7 +49,7 @@ def _spawn_contender(self, contender_id, election): thread.start() return thread - def _election_thread(self, contender_id, election): + def _election_thread(self, contender_id: str, election: Election) -> None: try: election.run(self._leader_func, contender_id) except UniqueError: @@ -50,9 +60,13 @@ def _election_thread(self, contender_id, election): else: if self.raise_exception: e = Exception("expected leader func to raise exception") - self.thread_exc_info = (Exception, e, None) + self.thread_exc_info = ( + Exception, + e, + cast("TracebackType", None), + ) - def _leader_func(self, name): + def _leader_func(self, name: str) -> None: exit_event = threading.Event() with self.condition: self.exit_event = exit_event @@ -63,12 +77,13 @@ def _leader_func(self, name): if self.raise_exception: raise UniqueError("expected error in the leader function") - def _check_thread_error(self): - if self.thread_exc_info: + def _check_thread_error(self) -> None: + if self.thread_exc_info is not None: t, o, tb = self.thread_exc_info + assert t is not None raise t(o) - def test_election(self): + def test_election(self) -> None: elections = {} threads = {} for _ in range(3): @@ -105,6 +120,7 @@ def test_election(self): elections[contenders[1]].cancel() # make leader exit. third contender should be elected. + assert self.exit_event is not None self.exit_event.set() with self.condition: while self.leader_id == first_leader: @@ -137,7 +153,8 @@ def test_election(self): thread.join() self._check_thread_error() - def test_bad_func(self): + def test_bad_func(self) -> None: election = self.client.Election(self.path) + # FIXME If we're using type hints, we don't need to check this. with pytest.raises(ValueError): - election.run("not a callable") + election.run("not a callable") # type: ignore[arg-type] diff --git a/kazoo/tests/test_eventlet_handler.py b/kazoo/tests/test_eventlet_handler.py index ff2649d9..08db6768 100644 --- a/kazoo/tests/test_eventlet_handler.py +++ b/kazoo/tests/test_eventlet_handler.py @@ -1,27 +1,28 @@ +from __future__ import annotations + import contextlib import unittest +from typing import Generator, Literal + import pytest -from kazoo.client import KazooClient +from kazoo.handlers.utils import create_tcp_socket from kazoo.handlers import utils from kazoo.protocol import states as kazoo_states -from kazoo.tests import test_client -from kazoo.tests import test_lock -from kazoo.tests import util as test_util try: - import eventlet - from eventlet.green import threading + from eventlet.green import socket from kazoo.handlers import eventlet as eventlet_handler - - EVENTLET_HANDLER_AVAILABLE = True + from kazoo.handlers.eventlet import SequentialEventletHandler except ImportError: - EVENTLET_HANDLER_AVAILABLE = False + pytestmark = pytest.mark.skip(reason="eventlet not available") @contextlib.contextmanager -def start_stop_one(handler=None): +def start_stop_one( + handler: SequentialEventletHandler = None, # type: ignore[assignment] +) -> Generator[SequentialEventletHandler]: if not handler: handler = eventlet_handler.SequentialEventletHandler() handler.start() @@ -32,22 +33,17 @@ def start_stop_one(handler=None): class TestEventletHandler(unittest.TestCase): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletHandler, self).setUp() - - def test_started(self): + def test_started(self) -> None: with start_stop_one() as handler: assert handler.running is True assert len(handler._workers) != 0 assert handler.running is False - assert len(handler._workers) == 0 + assert len(handler._workers) == 0 # type: ignore[unreachable] - def test_spawn(self): + def test_spawn(self) -> None: captures = [] - def cb(): + def cb() -> None: captures.append(1) with start_stop_one() as handler: @@ -55,10 +51,10 @@ def cb(): assert len(captures) == 1 - def test_dispatch(self): + def test_dispatch(self) -> None: captures = [] - def cb(): + def cb() -> None: captures.append(1) with start_stop_one() as handler: @@ -66,10 +62,10 @@ def cb(): assert len(captures) == 1 - def test_async_link(self): - captures = [] + def test_async_link(self) -> None: + captures: list[SequentialEventletHandler] = [] - def cb(handler): + def cb(handler: SequentialEventletHandler) -> None: captures.append(handler) with start_stop_one() as handler: @@ -80,20 +76,20 @@ def cb(handler): assert len(captures) == 1 assert r.get() == 2 - def test_timeout_raising(self): + def test_timeout_raising(self) -> None: handler = eventlet_handler.SequentialEventletHandler() with pytest.raises(handler.timeout_exception): raise handler.timeout_exception("This is a timeout") - def test_async_ok(self): - captures = [] + def test_async_ok(self) -> None: + captures: list[Literal[1] | SequentialEventletHandler] = [] - def delayed(): + def delayed() -> Literal[1]: captures.append(1) return 1 - def after_delayed(handler): + def after_delayed(handler: SequentialEventletHandler) -> None: captures.append(handler) with start_stop_one() as handler: @@ -106,7 +102,7 @@ def after_delayed(handler): assert captures[0] == 1 assert r.get() == 1 - def test_get_with_no_block(self): + def test_get_with_no_block(self) -> None: handler = eventlet_handler.SequentialEventletHandler() with start_stop_one(handler): @@ -117,8 +113,8 @@ def test_get_with_no_block(self): r.set(1) assert r.get() == 1 - def test_async_exception(self): - def broken(): + def test_async_exception(self) -> None: + def broken() -> None: raise IOError("Failed") with start_stop_one() as handler: @@ -130,13 +126,11 @@ def broken(): with pytest.raises(IOError): r.get() - def test_huge_file_descriptor(self): + def test_huge_file_descriptor(self) -> None: try: import resource except ImportError: self.skipTest("resource module unavailable on this platform") - from eventlet.green import socket - from kazoo.handlers.utils import create_tcp_socket try: resource.setrlimit(resource.RLIMIT_NOFILE, (4096, 4096)) @@ -154,91 +148,3 @@ def test_huge_file_descriptor(self): h.stop() for sock in socks: sock.close() - - -class TestEventletClient(test_client.TestClient): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletClient, self).setUp() - - @staticmethod - def make_event(): - return threading.Event() - - @staticmethod - def make_condition(): - return threading.Condition() - - def _makeOne(self, *args): - return eventlet_handler.SequentialEventletHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - return KazooClient(self.hosts, **kwargs) - - -class TestEventletSemaphore(test_lock.TestSemaphore): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletSemaphore, self).setUp() - - @staticmethod - def make_condition(): - return threading.Condition() - - @staticmethod - def make_event(): - return threading.Event() - - @staticmethod - def make_thread(*args, **kwargs): - return threading.Thread(*args, **kwargs) - - def _makeOne(self, *args): - return eventlet_handler.SequentialEventletHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - c = KazooClient(self.hosts, **kwargs) - try: - self._clients.append(c) - except AttributeError: - self._client = [c] - return c - - -class TestEventletLock(test_lock.KazooLockTests): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletLock, self).setUp() - - @staticmethod - def make_condition(): - return threading.Condition() - - @staticmethod - def make_event(): - return threading.Event() - - @staticmethod - def make_thread(*args, **kwargs): - return threading.Thread(*args, **kwargs) - - @staticmethod - def make_wait(): - return test_util.Wait(getsleep=(lambda: eventlet.sleep)) - - def _makeOne(self, *args): - return eventlet_handler.SequentialEventletHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - c = KazooClient(self.hosts, **kwargs) - try: - self._clients.append(c) - except AttributeError: - self._client = [c] - return c diff --git a/kazoo/tests/test_exceptions.py b/kazoo/tests/test_exceptions.py index d2fb9c6f..ff35a3fd 100644 --- a/kazoo/tests/test_exceptions.py +++ b/kazoo/tests/test_exceptions.py @@ -1,30 +1,34 @@ +from __future__ import annotations + from unittest import TestCase +from types import ModuleType + import pytest class ExceptionsTestCase(TestCase): - def _get(self): + def _get(self) -> ModuleType: from kazoo import exceptions return exceptions - def test_backwards_alias(self): + def test_backwards_alias(self) -> None: module = self._get() assert hasattr(module, "NoNodeException") assert module.NoNodeException is module.NoNodeError - def test_exceptions_code(self): + def test_exceptions_code(self) -> None: module = self._get() exc_8 = module.EXCEPTIONS[-8] assert isinstance(exc_8(), module.BadArgumentsError) - def test_invalid_code(self): + def test_invalid_code(self) -> None: module = self._get() with pytest.raises(RuntimeError): module.EXCEPTIONS.__getitem__(666) - def test_exceptions_construction(self): + def test_exceptions_construction(self) -> None: module = self._get() exc = module.EXCEPTIONS[-101]() assert type(exc) is module.NoNodeError diff --git a/kazoo/tests/test_gevent_handler.py b/kazoo/tests/test_gevent_handler.py index 28dd46b0..c0c98d92 100644 --- a/kazoo/tests/test_gevent_handler.py +++ b/kazoo/tests/test_gevent_handler.py @@ -1,61 +1,60 @@ +from __future__ import annotations + import unittest import sys +from typing import Any, Type import pytest -from kazoo.client import KazooClient from kazoo.exceptions import NoNodeError -from kazoo.protocol.states import Callback +from kazoo.handlers.utils import create_tcp_socket +from kazoo.protocol.states import Callback, ZnodeStat from kazoo.testing import KazooTestCase -from kazoo.tests import test_client + +try: + import gevent # NOQA: + from gevent.event import Event + from gevent.queue import Empty + from gevent import socket + from kazoo.handlers.gevent import AsyncResult, SequentialGeventHandler +except ImportError: + pytestmark = pytest.mark.skip(reason="gevent not available") @pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") class TestGeventHandler(unittest.TestCase): - def setUp(self): - try: - import gevent # NOQA - except ImportError: - pytest.skip("gevent not available.") - - def _makeOne(self, *args): - from kazoo.handlers.gevent import SequentialGeventHandler - + def _makeOne(self, *args: Any) -> SequentialGeventHandler: return SequentialGeventHandler(*args) - def _getAsync(self, *args): - from kazoo.handlers.gevent import AsyncResult - + def _getAsync(self) -> Type[AsyncResult[Any]]: return AsyncResult - def _getEvent(self): - from gevent.event import Event - + def _getEvent(self) -> Type[Event]: return Event - def test_proper_threading(self): + def test_proper_threading(self) -> None: h = self._makeOne() h.start() assert isinstance(h.event_object(), self._getEvent()) - def test_matching_async(self): + def test_matching_async(self) -> None: h = self._makeOne() h.start() async_handler = self._getAsync() assert isinstance(h.async_result(), async_handler) - def test_exception_raising(self): + def test_exception_raising(self) -> None: h = self._makeOne() with pytest.raises(h.timeout_exception): raise h.timeout_exception("This is a timeout") - def test_exception_in_queue(self): + def test_exception_in_queue(self) -> None: h = self._makeOne() h.start() ev = self._getEvent()() - def func(): + def func() -> None: ev.set() raise ValueError("bang") @@ -63,14 +62,12 @@ def func(): h.dispatch_callback(call1) ev.wait() - def test_queue_empty_exception(self): - from gevent.queue import Empty - + def test_queue_empty_exception(self) -> None: h = self._makeOne() h.start() ev = self._getEvent()() - def func(): + def func() -> None: ev.set() raise Empty() @@ -81,30 +78,22 @@ def func(): @pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") class TestBasicGeventClient(KazooTestCase): - def setUp(self): - try: - import gevent # NOQA - except ImportError: - pytest.skip("gevent not available.") + def setUp(self) -> None: KazooTestCase.setUp(self) - def _makeOne(self, *args): - from kazoo.handlers.gevent import SequentialGeventHandler - + def _makeOne(self, *args: Any) -> SequentialGeventHandler: return SequentialGeventHandler(*args) - def _getEvent(self): - from gevent.event import Event - + def _getEvent(self) -> Type[Event]: return Event - def test_start(self): + def test_start(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() assert client.state == "CONNECTED" client.stop() - def test_start_stop_double(self): + def test_start_stop_double(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() assert client.state == "CONNECTED" @@ -112,7 +101,7 @@ def test_start_stop_double(self): client.handler.stop() client.stop() - def test_basic_commands(self): + def test_basic_commands(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() assert client.state == "CONNECTED" @@ -122,22 +111,23 @@ def test_basic_commands(self): assert client.exists("/anode") is None client.stop() - def test_failures(self): + def test_failures(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() with pytest.raises(NoNodeError): client.get("/none") client.stop() - def test_data_watcher(self): + def test_data_watcher(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() client.ensure_path("/some/node") ev = self._getEvent()() @client.DataWatch("/some/node") - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: ev.set() + return None ev.wait() ev.clear() @@ -145,11 +135,11 @@ def changed(d, stat): ev.wait() client.stop() - def test_huge_file_descriptor(self): - import resource - from gevent import socket - from kazoo.handlers.utils import create_tcp_socket - + def test_huge_file_descriptor(self) -> None: + try: + import resource + except ImportError: + self.skipTest("resource module unavailable on this platform") try: resource.setrlimit(resource.RLIMIT_NOFILE, (4096, 4096)) except (ValueError, resource.error): @@ -166,22 +156,3 @@ def test_huge_file_descriptor(self): h.stop() for sock in socks: sock.close() - - -@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") -class TestGeventClient(test_client.TestClient): - def setUp(self): - try: - import gevent # NOQA - except ImportError: - pytest.skip("gevent not available.") - KazooTestCase.setUp(self) - - def _makeOne(self, *args): - from kazoo.handlers.gevent import SequentialGeventHandler - - return SequentialGeventHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - return KazooClient(self.hosts, **kwargs) diff --git a/kazoo/tests/test_hosts.py b/kazoo/tests/test_hosts.py index a2fae1ae..80517d5d 100644 --- a/kazoo/tests/test_hosts.py +++ b/kazoo/tests/test_hosts.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from unittest import TestCase from kazoo.hosts import collect_hosts class HostsTestCase(TestCase): - def test_ipv4(self): + def test_ipv4(self) -> None: hosts, chroot = collect_hosts( "127.0.0.1:2181, 192.168.1.2:2181, \ 132.254.111.10:2181" @@ -26,7 +28,7 @@ def test_ipv4(self): ] assert chroot is None - def test_ipv6(self): + def test_ipv6(self) -> None: hosts, chroot = collect_hosts("[fe80::200:5aee:feaa:20a2]:2181") assert hosts == [("fe80::200:5aee:feaa:20a2", 2181)] assert chroot is None @@ -35,7 +37,7 @@ def test_ipv6(self): assert hosts == [("fe80::200:5aee:feaa:20a2", 2181)] assert chroot is None - def test_hosts_list(self): + def test_hosts_list(self) -> None: hosts, chroot = collect_hosts("zk01:2181, zk02:2181, zk03:2181") expected1 = [("zk01", 2181), ("zk02", 2181), ("zk03", 2181)] assert hosts == expected1 diff --git a/kazoo/tests/test_interrupt.py b/kazoo/tests/test_interrupt.py index ad4ae5d6..fd45ba1e 100644 --- a/kazoo/tests/test_interrupt.py +++ b/kazoo/tests/test_interrupt.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from sys import platform @@ -7,7 +9,7 @@ class KazooInterruptTests(KazooTestCase): - def test_interrupted_systemcall(self): + def test_interrupted_systemcall(self) -> None: """ Make sure interrupted system calls don't break the world, since we can't control what all signals our connection thread will get diff --git a/kazoo/tests/test_lease.py b/kazoo/tests/test_lease.py index 98a12560..2b8cadb3 100644 --- a/kazoo/tests/test_lease.py +++ b/kazoo/tests/test_lease.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import uuid @@ -8,18 +10,18 @@ class MockClock(object): - def __init__(self, epoch=0): + def __init__(self, epoch: float = 0): self.epoch = epoch - def forward(self, seconds): + def forward(self, seconds: float) -> None: self.epoch += seconds - def __call__(self): + def __call__(self) -> datetime.datetime: return datetime.datetime.utcfromtimestamp(self.epoch) class KazooLeaseTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooLeaseTests, self).setUp() self.client2 = self._get_client(timeout=0.8) self.client2.start() @@ -28,7 +30,7 @@ def setUp(self): self.path = "/" + uuid.uuid4().hex self.clock = MockClock(10) - def tearDown(self): + def tearDown(self) -> None: for cl in [self.client2, self.client3]: if cl.connected: cl.stop() @@ -38,7 +40,7 @@ def tearDown(self): class NonBlockingLeaseTests(KazooLeaseTests): - def test_renew(self): + def test_renew(self) -> None: # Use client convenience method here to test it at least once. Use # class directly in # other tests in order to get better IDE support. @@ -54,7 +56,7 @@ def test_renew(self): ) assert renewed_lease - def test_busy(self): + def test_busy(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -74,7 +76,7 @@ def test_busy(self): assert not foreigner_lease assert foreigner_lease.obtained is False - def test_overtake(self): + def test_overtake(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -93,7 +95,7 @@ def test_overtake(self): ) assert foreigner_lease - def test_renew_no_overtake(self): + def test_renew_no_overtake(self) -> None: lease = self.client.NonBlockingLease( self.path, datetime.timedelta(seconds=3), utcnow=self.clock ) @@ -116,7 +118,7 @@ def test_renew_no_overtake(self): ) assert not foreigner_lease - def test_overtaker_renews(self): + def test_overtaker_renews(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -145,7 +147,7 @@ def test_overtaker_renews(self): ) assert foreigner_renew - def test_overtake_refuse_first(self): + def test_overtake_refuse_first(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -173,7 +175,7 @@ def test_overtake_refuse_first(self): ) assert not first_again_lease - def test_old_version(self): + def test_old_version(self) -> None: # Skip to a future version NonBlockingLease._version += 1 lease = NonBlockingLease( @@ -199,7 +201,7 @@ def test_old_version(self): class MultiNonBlockingLeaseTest(KazooLeaseTests): - def test_1_renew(self): + def test_1_renew(self) -> None: ls = self.client.MultiNonBlockingLease( 1, self.path, datetime.timedelta(seconds=4), utcnow=self.clock ) @@ -214,7 +216,7 @@ def test_1_renew(self): ) assert ls2 - def test_1_reject(self): + def test_1_reject(self) -> None: ls = MultiNonBlockingLease( self.client, 1, @@ -234,7 +236,7 @@ def test_1_reject(self): ) assert not ls2 - def test_2_renew(self): + def test_2_renew(self) -> None: ls = MultiNonBlockingLease( self.client, 2, @@ -273,7 +275,7 @@ def test_2_renew(self): ) assert ls4 - def test_2_reject(self): + def test_2_reject(self) -> None: ls = MultiNonBlockingLease( self.client, 2, @@ -303,7 +305,7 @@ def test_2_reject(self): ) assert not ls3 - def test_2_handover(self): + def test_2_handover(self) -> None: ls = MultiNonBlockingLease( self.client, 2, diff --git a/kazoo/tests/test_lock.py b/kazoo/tests/test_lock.py index 7e5ecb77..18ac4b51 100644 --- a/kazoo/tests/test_lock.py +++ b/kazoo/tests/test_lock.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import collections import threading import unittest from unittest.mock import MagicMock import uuid +from typing import Any, Callable, Deque +from types import TracebackType +from threading import Thread + import pytest from kazoo.exceptions import CancelledError from kazoo.exceptions import LockTimeout from kazoo.exceptions import NoNodeError -from kazoo.recipe.lock import Lock +from kazoo.recipe.lock import Lock, Semaphore from kazoo.testing import KazooTestCase from kazoo.tests import util as test_util @@ -17,22 +23,28 @@ class SleepBarrier(object): """A crappy spinning barrier.""" - def __init__(self, wait_for, sleep_func): + def __init__(self, wait_for: int, sleep_func: Callable[..., None]): self._wait_for = wait_for - self._arrived = collections.deque() + self._arrived: Deque[Thread] = collections.deque() self._sleep_func = sleep_func - def __enter__(self): + def __enter__(self) -> SleepBarrier: self._arrived.append(threading.current_thread()) return self - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: try: self._arrived.remove(threading.current_thread()) except ValueError: pass + return None - def wait(self): + def wait(self) -> None: while len(self._arrived) < self._wait_for: self._sleep_func(0.001) @@ -40,43 +52,45 @@ def wait(self): class KazooLockTests(KazooTestCase): thread_count = 20 - def __init__(self, *args, **kw): + def __init__(self, *args: None, **kw: None): super(KazooLockTests, self).__init__(*args, **kw) - self.threads_made = [] + self.threads_made: list[threading.Thread] = [] - def tearDown(self): + def tearDown(self) -> None: super(KazooLockTests, self).tearDown() while self.threads_made: t = self.threads_made.pop() t.join() @staticmethod - def make_condition(): + def make_condition() -> threading.Condition: return threading.Condition() @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def make_thread(self, *args, **kwargs): + def make_thread(self, *args: Any, **kwargs: Any) -> threading.Thread: t = threading.Thread(*args, **kwargs) t.daemon = True self.threads_made.append(t) return t @staticmethod - def make_wait(): + def make_wait() -> test_util.Wait: return test_util.Wait() - def setUp(self): + def setUp(self) -> None: super(KazooLockTests, self).setUp() self.lockpath = "/" + uuid.uuid4().hex self.condition = self.make_condition() self.released = self.make_event() - self.active_thread = None - self.cancelled_threads = [] + self.active_thread: str | None = None + self.cancelled_threads: list[str] = [] - def _thread_lock_acquire_til_event(self, name, lock, event): + def _thread_lock_acquire_til_event( + self, name: str, lock: Lock, event: threading.Event + ) -> None: try: with lock: with self.condition: @@ -96,7 +110,7 @@ def _thread_lock_acquire_til_event(self, name, lock, event): self.cancelled_threads.append(name) self.condition.notify_all() - def test_lock_one(self): + def test_lock_one(self) -> None: lock_name = uuid.uuid4().hex lock = self.client.Lock(self.lockpath, lock_name) event = self.make_event() @@ -127,7 +141,7 @@ def test_lock_one(self): self.released.wait() thread.join() - def test_lock(self): + def test_lock(self) -> None: threads = [] names = ["contender" + str(i) for i in range(5)] @@ -181,7 +195,7 @@ def test_lock(self): for thread in threads: thread.join() - def test_lock_reconnect(self): + def test_lock_reconnect(self) -> None: event = self.make_event() other_lock = self.client.Lock(self.lockpath, "contender") thread = self.make_thread( @@ -211,7 +225,7 @@ def test_lock_reconnect(self): event.set() thread.join() - def test_lock_non_blocking(self): + def test_lock_non_blocking(self) -> None: lock_name = uuid.uuid4().hex lock = self.client.Lock(self.lockpath, lock_name) event = self.make_event() @@ -235,7 +249,7 @@ def test_lock_non_blocking(self): event.set() thread.join() - def test_lock_fail_first_call(self): + def test_lock_fail_first_call(self) -> None: event1 = self.make_event() lock1 = self.client.Lock(self.lockpath, "one") thread1 = self.make_thread( @@ -248,12 +262,15 @@ def test_lock_fail_first_call(self): with self.condition: if not self.active_thread: self.condition.wait(5) - assert self.active_thread == "one" + assert ( + self.active_thread + == "one" # type: ignore[comparison-overlap] + ) assert lock1.contenders() == ["one"] event1.set() thread1.join() - def test_lock_cancel(self): + def test_lock_cancel(self) -> None: event1 = self.make_event() lock1 = self.client.Lock(self.lockpath, "one") thread1 = self.make_thread( @@ -266,7 +283,10 @@ def test_lock_cancel(self): with self.condition: if not self.active_thread: self.condition.wait(5) - assert self.active_thread == "one" + assert ( + self.active_thread + == "one" # type: ignore[comparison-overlap] + ) client2 = self._get_client() client2.start() @@ -296,7 +316,7 @@ def test_lock_cancel(self): thread1.join() client2.stop() - def test_lock_no_double_calls(self): + def test_lock_no_double_calls(self) -> None: lock1 = self.client.Lock(self.lockpath, "one") lock1.acquire() assert lock1.is_acquired is True @@ -305,7 +325,7 @@ def test_lock_no_double_calls(self): lock1.release() assert lock1.is_acquired is False - def test_lock_same_thread_no_block(self): + def test_lock_same_thread_no_block(self) -> None: lock = self.client.Lock(self.lockpath, "one") gotten = lock.acquire(blocking=False) assert gotten is True @@ -313,11 +333,11 @@ def test_lock_same_thread_no_block(self): gotten = lock.acquire(blocking=False) assert gotten is False - def test_lock_many_threads_no_block(self): + def test_lock_many_threads_no_block(self) -> None: lock = self.client.Lock(self.lockpath, "one") - attempts = collections.deque() + attempts: Deque[int] = collections.deque() - def _acquire(): + def _acquire() -> None: attempts.append(int(lock.acquire(blocking=False))) threads = [] @@ -332,14 +352,14 @@ def _acquire(): assert sum(list(attempts)) == 1 - def test_lock_many_threads(self): + def test_lock_many_threads(self) -> None: sleep_func = self.client.handler.sleep_func lock = self.client.Lock(self.lockpath, "one") - acquires = collections.deque() - differences = collections.deque() + acquires: Deque[int] = collections.deque() + differences: Deque[int] = collections.deque() barrier = SleepBarrier(self.thread_count, sleep_func) - def _acquire(): + def _acquire() -> None: # Wait until all threads are ready to go... with barrier as b: b.wait() @@ -366,18 +386,19 @@ def _acquire(): assert len(acquires) == self.thread_count assert list(differences) == [1] * self.thread_count - def test_lock_reacquire(self): + def test_lock_reacquire(self) -> None: lock = self.client.Lock(self.lockpath, "one") lock.acquire() lock.release() lock.acquire() lock.release() - def test_lock_ephemeral(self): + def test_lock_ephemeral(self) -> None: client1 = self._get_client() client1.start() lock = client1.Lock(self.lockpath, "ephemeral") lock.acquire(ephemeral=False) + assert lock.node is not None znode = self.lockpath + "/" + lock.node client1.stop() try: @@ -385,7 +406,7 @@ def test_lock_ephemeral(self): except NoNodeError: self.fail("NoNodeError raised unexpectedly!") - def test_lock_timeout(self): + def test_lock_timeout(self) -> None: timeout = 3 e = self.make_event() started = self.make_event() @@ -394,7 +415,9 @@ def test_lock_timeout(self): # that the main thread is going to wait to acquire the lock. lock1 = self.client.Lock(self.lockpath, "one") - def _thread(lock, event, timeout): + def _thread( + lock: threading.Lock, event: threading.Event, timeout: float + ) -> None: with lock: started.set() event.wait(timeout) @@ -426,7 +449,7 @@ def _thread(lock, event, timeout): t.join() client2.stop() - def test_read_lock(self): + def test_read_lock(self) -> None: # Test that we can obtain a read lock lock = self.client.ReadLock(self.lockpath, "reader one") gotten = lock.acquire(blocking=False) @@ -451,7 +474,7 @@ def test_read_lock(self): gotten = lock3.acquire(blocking=False) assert gotten is False - def test_write_lock(self): + def test_write_lock(self) -> None: # Test that we can obtain a write lock lock = self.client.WriteLock(self.lockpath, "writer") gotten = lock.acquire(blocking=False) @@ -468,7 +491,7 @@ def test_write_lock(self): gotten = lock2.acquire(blocking=False) assert gotten is False - def test_rw_lock(self): + def test_rw_lock(self) -> None: reader_event = self.make_event() reader_lock = self.client.ReadLock(self.lockpath, "reader") reader_thread = self.make_thread( @@ -508,8 +531,8 @@ def test_rw_lock(self): # release the lock and contenders should claim it in order lock.release() - for contender, contender_bits in contender_bits.items(): - _, event = contender_bits + for contender, bits in contender_bits.items(): + _, event = bits with self.condition: while not self.active_thread: @@ -530,44 +553,44 @@ def test_rw_lock(self): class TestSemaphore(KazooTestCase): - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any): super(TestSemaphore, self).__init__(*args, **kw) - self.threads_made = [] + self.threads_made: list[threading.Thread] = [] - def tearDown(self): + def tearDown(self) -> None: super(TestSemaphore, self).tearDown() while self.threads_made: t = self.threads_made.pop() t.join() @staticmethod - def make_condition(): + def make_condition() -> threading.Condition: return threading.Condition() @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def make_thread(self, *args, **kwargs): + def make_thread(self, *args: Any, **kwargs: Any) -> threading.Thread: t = threading.Thread(*args, **kwargs) t.daemon = True self.threads_made.append(t) return t - def setUp(self): + def setUp(self) -> None: super(TestSemaphore, self).setUp() self.lockpath = "/" + uuid.uuid4().hex self.condition = self.make_condition() self.released = self.make_event() self.active_thread = None - self.cancelled_threads = [] + self.cancelled_threads: list[str] = [] - def test_basic(self): + def test_basic(self) -> None: sem1 = self.client.Semaphore(self.lockpath) sem1.acquire() sem1.release() - def test_lock_one(self): + def test_lock_one(self) -> None: sem1 = self.client.Semaphore(self.lockpath, max_leases=1) sem2 = self.client.Semaphore(self.lockpath, max_leases=1) started = self.make_event() @@ -575,7 +598,7 @@ def test_lock_one(self): sem1.acquire() - def sema_one(): + def sema_one() -> None: started.set() with sem2: event.set() @@ -591,7 +614,7 @@ def sema_one(): assert event.is_set() is True thread.join() - def test_non_blocking(self): + def test_non_blocking(self) -> None: sem1 = self.client.Semaphore( self.lockpath, identifier="sem1", max_leases=2 ) @@ -613,7 +636,7 @@ def test_non_blocking(self): sem1.release() sem3.release() - def test_non_blocking_release(self): + def test_non_blocking_release(self) -> None: sem1 = self.client.Semaphore( self.lockpath, identifier="sem1", max_leases=1 ) @@ -627,11 +650,11 @@ def test_non_blocking_release(self): sem1.release() sem2.release() - def test_holders(self): + def test_holders(self) -> None: started = self.make_event() event = self.make_event() - def sema_one(): + def sema_one() -> None: with self.client.Semaphore(self.lockpath, "fred", max_leases=1): started.set() event.wait() @@ -645,14 +668,14 @@ def sema_one(): event.set() thread.join() - def test_semaphore_cancel(self): + def test_semaphore_cancel(self) -> None: sem1 = self.client.Semaphore(self.lockpath, "fred", max_leases=1) sem2 = self.client.Semaphore(self.lockpath, "george", max_leases=1) sem1.acquire() started = self.make_event() event = self.make_event() - def sema_one(): + def sema_one() -> None: started.set() try: with sem2: @@ -670,7 +693,7 @@ def sema_one(): assert event.is_set() thread.join() - def test_multiple_acquire_and_release(self): + def test_multiple_acquire_and_release(self) -> None: sem1 = self.client.Semaphore(self.lockpath, "fred", max_leases=1) sem1.acquire() sem1.acquire() @@ -678,7 +701,7 @@ def test_multiple_acquire_and_release(self): assert sem1.release() assert not sem1.release() - def test_handle_session_loss(self): + def test_handle_session_loss(self) -> None: expire_semaphore = self.client.Semaphore( self.lockpath, "fred", max_leases=1 ) @@ -692,7 +715,7 @@ def test_handle_session_loss(self): event = self.make_event() event2 = self.make_event() - def sema_one(): + def sema_one() -> None: started.set() with expire_semaphore: event.set() @@ -707,7 +730,7 @@ def sema_one(): # Fired in a separate thread to make sure we can see the effect expired = self.make_event() - def expire(): + def expire() -> None: self.expire_session(self.make_event) expired.set() @@ -726,7 +749,7 @@ def expire(): for t in (thread1, thread2): t.join() - def test_inconsistent_max_leases(self): + def test_inconsistent_max_leases(self) -> None: sem1 = self.client.Semaphore(self.lockpath, max_leases=1) sem2 = self.client.Semaphore(self.lockpath, max_leases=2) @@ -734,7 +757,7 @@ def test_inconsistent_max_leases(self): with pytest.raises(ValueError): sem2.acquire() - def test_inconsistent_max_leases_other_data(self): + def test_inconsistent_max_leases_other_data(self) -> None: sem1 = self.client.Semaphore(self.lockpath, max_leases=1) sem2 = self.client.Semaphore(self.lockpath, max_leases=2) @@ -745,14 +768,14 @@ def test_inconsistent_max_leases_other_data(self): # sem2 thinks it's ok to have two lease holders assert sem2.acquire(blocking=False) - def test_reacquire(self): + def test_reacquire(self) -> None: lock = self.client.Semaphore(self.lockpath) lock.acquire() lock.release() lock.acquire() lock.release() - def test_acquire_after_cancelled(self): + def test_acquire_after_cancelled(self) -> None: lock = self.client.Semaphore(self.lockpath) assert lock.acquire() is True assert lock.release() is True @@ -760,7 +783,7 @@ def test_acquire_after_cancelled(self): assert lock.cancelled is True assert lock.acquire() is True - def test_timeout(self): + def test_timeout(self) -> None: timeout = 3 e = self.make_event() started = self.make_event() @@ -769,7 +792,9 @@ def test_timeout(self): # that the main thread is going to wait to acquire the lock. sem1 = self.client.Semaphore(self.lockpath, "one") - def _thread(sem, event, timeout): + def _thread( + sem: Semaphore, event: threading.Event, timeout: float + ) -> None: with sem: started.set() event.wait(timeout) @@ -800,7 +825,7 @@ def _thread(sem, event, timeout): class TestSequence(unittest.TestCase): - def test_get_predecessor(self): + def test_get_predecessor(self) -> None: """Validate selection of predecessors.""" goLock = "_c_8eb60557ba51e0da67eefc47467d3f34-lock-0000000031" pyLock = "514e5a831836450cb1a56c741e990fd8__lock__0000000032" @@ -810,7 +835,7 @@ def test_get_predecessor(self): lock = Lock(client, "test") assert lock._get_predecessor(pyLock) is None - def test_get_predecessor_go(self): + def test_get_predecessor_go(self) -> None: """Test selection of predecessor when instructed to consider go-zk locks. """ diff --git a/kazoo/tests/test_partitioner.py b/kazoo/tests/test_partitioner.py index b2b91c8b..35076ee2 100644 --- a/kazoo/tests/test_partitioner.py +++ b/kazoo/tests/test_partitioner.py @@ -1,11 +1,15 @@ +from __future__ import annotations + import uuid import threading import time from unittest.mock import patch +from kazoo.client import KazooClient +from kazoo.interfaces import Lockable from kazoo.exceptions import LockTimeout from kazoo.testing import KazooTestCase -from kazoo.recipe.partitioner import PartitionState +from kazoo.recipe.partitioner import PartitionState, SetPartitioner class SlowLockMock: @@ -13,14 +17,19 @@ class SlowLockMock: default_delay_time = 3 - def __init__(self, client, lock, delay_time=None): + def __init__( + self, + client: KazooClient, + lock: Lockable, + delay_time: float | None = None, + ): self._client = client self._lock = lock self.delay_time = ( self.default_delay_time if delay_time is None else delay_time ) - def acquire(self, timeout=None): + def acquire(self, timeout: float | None = None) -> bool: sleep = self._client.handler.sleep_func sleep(self.delay_time) @@ -37,28 +46,32 @@ def acquire(self, timeout=None): raise LockTimeout("Mocked slow lock has timed out.") - def release(self): + def release(self) -> None: self._lock.release() +PartitionData = int +Partitioner = SetPartitioner[PartitionData] + + class KazooPartitionerTests(KazooTestCase): @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def setUp(self): + def setUp(self) -> None: super(KazooPartitionerTests, self).setUp() self.path = "/" + uuid.uuid4().hex - self.__partitioners = [] + self.__partitioners: list[Partitioner] = [] - def test_party_of_one(self): + def test_party_of_one(self) -> None: self.__create_partitioner(size=3) self.__wait_for_acquire() self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0, 1, 2]) self.__finish() - def test_party_of_two(self): + def test_party_of_two(self) -> None: for i in range(2): self.__create_partitioner(size=2, identifier=str(i)) @@ -70,7 +83,7 @@ def test_party_of_two(self): assert self.__partitioners[1].release self.__partitioners[1].finish() - def test_party_expansion(self): + def test_party_expansion(self) -> None: for i in range(2): self.__create_partitioner(size=3, identifier=str(i)) @@ -88,7 +101,7 @@ def test_party_expansion(self): self.__assert_state( PartitionState.RELEASE, partitioners=self.__partitioners[:-1] ) - for partitioner in self.__partitioners[-1]: + for partitioner in self.__partitioners[:-1]: assert partitioner.state_change_event.is_set() self.__release(self.__partitioners[:-1]) @@ -97,7 +110,7 @@ def test_party_expansion(self): self.__finish() - def test_more_members_than_set_items(self): + def test_more_members_than_set_items(self) -> None: for i in range(2): self.__create_partitioner(size=1, identifier=str(i)) @@ -107,7 +120,7 @@ def test_more_members_than_set_items(self): self.__finish() - def test_party_session_failure(self): + def test_party_session_failure(self) -> None: partitioner = self.__create_partitioner(size=3) self.__wait_for_acquire() assert partitioner.state == PartitionState.ACQUIRED @@ -116,7 +129,7 @@ def test_party_session_failure(self): partitioner.release_set() assert partitioner.failed is True - def test_connection_loss(self): + def test_connection_loss(self) -> None: self.__create_partitioner(identifier="0", size=3) self.__create_partitioner(identifier="1", size=3) @@ -146,10 +159,10 @@ def test_connection_loss(self): self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0], [1], [2]) - def test_race_condition_new_partitioner_during_the_lock(self): - locks = {} + def test_race_condition_new_partitioner_during_the_lock(self) -> None: + locks: dict[str, Lockable] = {} - def get_lock(path): + def get_lock(path: str) -> SlowLockMock: lock = locks.setdefault(path, self.client.handler.lock_object()) return SlowLockMock(self.client, lock) @@ -175,10 +188,10 @@ def get_lock(path): self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0], [1]) - def test_race_condition_new_partitioner_steals_the_lock(self): - locks = {} + def test_race_condition_new_partitioner_steals_the_lock(self) -> None: + locks: dict[str, Lockable] = {} - def get_lock(path): + def get_lock(path: str) -> SlowLockMock: new_lock = self.client.handler.lock_object() lock = locks.setdefault(path, new_lock) @@ -214,7 +227,9 @@ def get_lock(path): self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0], [1]) - def __create_partitioner(self, size, identifier=None): + def __create_partitioner( + self, size: int, identifier: str | None = None + ) -> Partitioner: partitioner = self.client.SetPartitioner( self.path, set=range(size), @@ -224,34 +239,38 @@ def __create_partitioner(self, size, identifier=None): self.__partitioners.append(partitioner) return partitioner - def __wait_for_acquire(self): + def __wait_for_acquire(self) -> None: for partitioner in self.__partitioners: partitioner.wait_for_acquire(14) - def __assert_state(self, state, partitioners=None): + def __assert_state( + self, + state: PartitionState, + partitioners: list[Partitioner] | None = None, + ) -> None: if partitioners is None: partitioners = self.__partitioners for partitioner in partitioners: assert partitioner.state == state - def __assert_partitions(self, *partitions): + def __assert_partitions(self, *partitions: list[PartitionData]) -> None: assert len(partitions) == len(self.__partitioners) for partitioner, own_partitions in zip( self.__partitioners, partitions ): assert list(partitioner) == own_partitions - def __wait(self): + def __wait(self) -> None: time.sleep(0.1) - def __release(self, partitioners=None): + def __release(self, partitioners: list[Partitioner] | None = None) -> None: if partitioners is None: partitioners = self.__partitioners for partitioner in partitioners: partitioner.release_set() - def __finish(self): + def __finish(self) -> None: for partitioner in self.__partitioners: partitioner.finish() diff --git a/kazoo/tests/test_party.py b/kazoo/tests/test_party.py index 1b32523c..f503eb33 100644 --- a/kazoo/tests/test_party.py +++ b/kazoo/tests/test_party.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import uuid from kazoo.testing import KazooTestCase class KazooPartyTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooPartyTests, self).setUp() self.path = "/" + uuid.uuid4().hex - def test_party(self): + def test_party(self) -> None: parties = [self.client.Party(self.path, "p%s" % i) for i in range(5)] one_party = parties[0] @@ -31,7 +33,7 @@ def test_party(self): assert set(party) == participants assert len(party) == len(participants) - def test_party_reuse_node(self): + def test_party_reuse_node(self) -> None: party = self.client.Party(self.path, "p1") self.client.ensure_path(self.path) self.client.create(party.create_path) @@ -39,24 +41,26 @@ def test_party_reuse_node(self): assert party.participating is True party.leave() assert party.participating is False - assert len(party) == 0 + # This appears to be an issue with mypy + assert len(party) == 0 # type: ignore[unreachable] - def test_party_vanishing_node(self): + def test_party_vanishing_node(self) -> None: party = self.client.Party(self.path, "p1") party.join() assert party.participating is True self.client.delete(party.create_path) party.leave() assert party.participating is False - assert len(party) == 0 + # This appears to be an issue with mypy + assert len(party) == 0 # type: ignore[unreachable] class KazooShallowPartyTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooShallowPartyTests, self).setUp() self.path = "/" + uuid.uuid4().hex - def test_party(self): + def test_party(self) -> None: parties = [ self.client.ShallowParty(self.path, "p%s" % i) for i in range(5) ] diff --git a/kazoo/tests/test_paths.py b/kazoo/tests/test_paths.py index 438c2eca..a8064d48 100644 --- a/kazoo/tests/test_paths.py +++ b/kazoo/tests/test_paths.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase import pytest @@ -6,67 +8,67 @@ class NormPathTestCase(TestCase): - def test_normpath(self): + def test_normpath(self) -> None: assert paths.normpath("/a/b") == "/a/b" - def test_normpath_empty(self): + def test_normpath_empty(self) -> None: assert paths.normpath("") == "" - def test_normpath_unicode(self): + def test_normpath_unicode(self) -> None: assert paths.normpath("/\xe4/b") == "/\xe4/b" - def test_normpath_dots(self): + def test_normpath_dots(self) -> None: assert paths.normpath("/a./b../c") == "/a./b../c" - def test_normpath_slash(self): + def test_normpath_slash(self) -> None: assert paths.normpath("/") == "/" - def test_normpath_multiple_slashes(self): + def test_normpath_multiple_slashes(self) -> None: assert paths.normpath("//") == "/" assert paths.normpath("//a/b") == "/a/b" assert paths.normpath("/a//b//") == "/a/b" assert paths.normpath("//a////b///c/") == "/a/b/c" - def test_normpath_relative(self): + def test_normpath_relative(self) -> None: with pytest.raises(ValueError): paths.normpath("./a/b") with pytest.raises(ValueError): paths.normpath("/a/../b") - def test_normpath_trailing(self): + def test_normpath_trailing(self) -> None: assert paths.normpath("/", trailing=True) == "/" class JoinTestCase(TestCase): - def test_join(self): + def test_join(self) -> None: assert paths.join("/a") == "/a" assert paths.join("/a", "b/") == "/a/b/" assert paths.join("/a", "b", "c") == "/a/b/c" - def test_join_empty(self): + def test_join_empty(self) -> None: assert paths.join("") == "" assert paths.join("", "a", "b") == "a/b" assert paths.join("/a", "", "b/", "c") == "/a/b/c" - def test_join_absolute(self): + def test_join_absolute(self) -> None: assert paths.join("/a/b", "/c") == "/c" class IsAbsTestCase(TestCase): - def test_isabs(self): + def test_isabs(self) -> None: assert paths.isabs("/") is True assert paths.isabs("/a") is True assert paths.isabs("/a//b/c") is True assert paths.isabs("//a/b") is True - def test_isabs_false(self): + def test_isabs_false(self) -> None: assert paths.isabs("") is False assert paths.isabs("a/") is False assert paths.isabs("a/../") is False class BaseNameTestCase(TestCase): - def test_basename(self): + def test_basename(self) -> None: assert paths.basename("") == "" assert paths.basename("/") == "" assert paths.basename("//a") == "a" @@ -75,7 +77,7 @@ def test_basename(self): class PrefixRootTestCase(TestCase): - def test_prefix_root(self): + def test_prefix_root(self) -> None: assert paths._prefix_root("/a/", "b/c") == "/a/b/c" assert paths._prefix_root("/a/b", "c/d") == "/a/b/c/d" assert paths._prefix_root("/a", "/b/c") == "/a/b/c" @@ -83,7 +85,7 @@ def test_prefix_root(self): class NormRootTestCase(TestCase): - def test_norm_root(self): + def test_norm_root(self) -> None: assert paths._norm_root("") == "/" assert paths._norm_root("/") == "/" assert paths._norm_root("//a") == "/a" diff --git a/kazoo/tests/test_queue.py b/kazoo/tests/test_queue.py index b4b8455f..0d86f57a 100644 --- a/kazoo/tests/test_queue.py +++ b/kazoo/tests/test_queue.py @@ -1,36 +1,42 @@ +from __future__ import annotations + import uuid +from typing import Any + import pytest +from kazoo.interfaces import Event +from kazoo.recipe.queue import LockingQueue, Queue from kazoo.testing import KazooTestCase from kazoo.tests.util import CI_ZK_VERSION class KazooQueueTests(KazooTestCase): - def _makeOne(self): + def _makeOne(self) -> Queue: path = "/" + uuid.uuid4().hex return self.client.Queue(path) - def test_queue_validation(self): + def test_queue_validation(self) -> None: queue = self._makeOne() with pytest.raises(TypeError): - queue.put({}) + queue.put({}) # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", b"100") + queue.put(b"one", b"100") # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", 10.0) + queue.put(b"one", 10.0) # type: ignore[arg-type] with pytest.raises(ValueError): queue.put(b"one", -100) with pytest.raises(ValueError): queue.put(b"one", 100000) - def test_empty_queue(self): + def test_empty_queue(self) -> None: queue = self._makeOne() assert len(queue) == 0 assert queue.get() is None assert len(queue) == 0 - def test_queue(self): + def test_queue(self) -> None: queue = self._makeOne() queue.put(b"one") queue.put(b"two") @@ -42,7 +48,7 @@ def test_queue(self): assert queue.get() == b"three" assert len(queue) == 0 - def test_priority(self): + def test_priority(self) -> None: queue = self._makeOne() queue.put(b"four", priority=101) queue.put(b"one", priority=0) @@ -56,7 +62,7 @@ def test_priority(self): class KazooLockingQueueTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) skip = False if CI_ZK_VERSION and CI_ZK_VERSION < (3, 4): @@ -70,42 +76,42 @@ def setUp(self): if skip: pytest.skip("Must use Zookeeper 3.4 or above") - def _makeOne(self): + def _makeOne(self) -> LockingQueue: path = "/" + uuid.uuid4().hex return self.client.LockingQueue(path) - def test_queue_validation(self): + def test_queue_validation(self) -> None: queue = self._makeOne() with pytest.raises(TypeError): - queue.put({}) + queue.put({}) # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", b"100") + queue.put(b"one", b"100") # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", 10.0) + queue.put(b"one", 10.0) # type: ignore[arg-type] with pytest.raises(ValueError): queue.put(b"one", -100) with pytest.raises(ValueError): queue.put(b"one", 100000) with pytest.raises(TypeError): - queue.put_all({}) + queue.put_all({}) # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put_all([{}]) + queue.put_all([{}]) # type: ignore[list-item] with pytest.raises(TypeError): - queue.put_all([b"one"], b"100") + queue.put_all([b"one"], b"100") # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put_all([b"one"], 10.0) + queue.put_all([b"one"], 10.0) # type: ignore[arg-type] with pytest.raises(ValueError): queue.put_all([b"one"], -100) with pytest.raises(ValueError): queue.put_all([b"one"], 100000) - def test_empty_queue(self): + def test_empty_queue(self) -> None: queue = self._makeOne() assert len(queue) == 0 assert queue.get(0) is None assert len(queue) == 0 - def test_queue(self): + def test_queue(self) -> None: queue = self._makeOne() queue.put(b"one") queue.put_all([b"two", b"three"]) @@ -130,7 +136,7 @@ def test_queue(self): assert not queue.consume() assert len(queue) == 0 - def test_consume(self): + def test_consume(self) -> None: queue = self._makeOne() queue.put(b"one") @@ -139,7 +145,7 @@ def test_consume(self): assert queue.consume() assert not queue.consume() - def test_release(self): + def test_release(self) -> None: queue = self._makeOne() queue.put(b"one") @@ -152,7 +158,7 @@ def test_release(self): assert not queue.release() assert len(queue) == 0 - def test_holds_lock(self): + def test_holds_lock(self) -> None: queue = self._makeOne() assert not queue.holds_lock() @@ -162,7 +168,7 @@ def test_holds_lock(self): queue.consume() assert not queue.holds_lock() - def test_priority(self): + def test_priority(self) -> None: queue = self._makeOne() queue.put(b"four", priority=101) queue.put(b"one", priority=0) @@ -178,16 +184,16 @@ def test_priority(self): assert queue.get(1) == b"four" assert queue.consume() - def test_concurrent_execution(self): + def test_concurrent_execution(self) -> None: queue = self._makeOne() - value1 = [] - value2 = [] - value3 = [] + value1: list[bytes | None] = [] + value2: list[bytes | None] = [] + value3: list[bytes | None] = [] event1 = self.client.handler.event_object() event2 = self.client.handler.event_object() event3 = self.client.handler.event_object() - def get_concurrently(value, event): + def get_concurrently(value: list[Any], event: Event) -> None: q = self.client.LockingQueue(queue.path) value.append(q.get(0.1)) event.set() diff --git a/kazoo/tests/test_retry.py b/kazoo/tests/test_retry.py index acbe5bd9..1c3f9282 100644 --- a/kazoo/tests/test_retry.py +++ b/kazoo/tests/test_retry.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import Any from unittest import mock import pytest @@ -6,16 +9,19 @@ from kazoo import retry as kr -def _make_retry(*args, **kwargs): +def _make_retry(*args: Any, **kwargs: Any) -> kr.KazooRetry: """Return a KazooRetry instance with a dummy sleep function.""" - def _sleep_func(_time): + def _sleep_func(_time: float) -> None: pass - return kr.KazooRetry(*args, sleep_func=_sleep_func, **kwargs) + # FIXME better way of doing this? Use TypedDict perhaps? + return kr.KazooRetry( + *args, sleep_func=_sleep_func, **kwargs # type: ignore[misc] + ) -def _make_try_func(times=1): +def _make_try_func(times: int = 1) -> mock.Mock: """Returns a function that raises ForceRetryError `times` time before returning None. """ @@ -25,7 +31,7 @@ def _make_try_func(times=1): return callmock -def test_call(): +def test_call() -> None: retry = _make_retry(delay=0, max_tries=2) func = _make_try_func() retry(func, "foo", bar="baz") @@ -35,7 +41,7 @@ def test_call(): ] -def test_reset(): +def test_reset() -> None: retry = _make_retry(delay=0, max_tries=2) func = _make_try_func() retry(func) @@ -46,7 +52,7 @@ def test_reset(): assert retry._attempts == 0 -def test_too_many_tries(): +def test_too_many_tries() -> None: retry = _make_retry(delay=0, max_tries=10) func = _make_try_func(times=999) with pytest.raises(kr.RetryFailedError): @@ -56,7 +62,7 @@ def test_too_many_tries(): ), "Called 10 times, failed _attempts 10" -def test_maximum_delay(): +def test_maximum_delay() -> None: retry = _make_retry(delay=10, max_tries=100, max_jitter=0) func = _make_try_func(times=2) retry(func) @@ -71,27 +77,27 @@ def test_maximum_delay(): assert isinstance(retry._cur_delay, float) -def test_copy(): +def test_copy() -> None: retry = _make_retry() rcopy = retry.copy() assert rcopy is not retry assert rcopy.sleep_func is retry.sleep_func -def test_connection_closed(): +def test_connection_closed() -> None: retry = _make_retry() - def testit(): + def testit() -> None: raise ke.ConnectionClosedError with pytest.raises(ke.ConnectionClosedError): retry(testit) -def test_session_expired(): +def test_session_expired() -> None: retry = _make_retry(max_tries=1) - def testit(): + def testit() -> None: raise ke.SessionExpiredError with pytest.raises(kr.RetryFailedError): diff --git a/kazoo/tests/test_sasl.py b/kazoo/tests/test_sasl.py index 6daa2a0b..c70806fb 100644 --- a/kazoo/tests/test_sasl.py +++ b/kazoo/tests/test_sasl.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess import time @@ -13,7 +15,7 @@ class TestLegacySASLDigestAuthentication(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: try: import puresasl # NOQA except ImportError: @@ -29,11 +31,11 @@ def setUp(self): if not version or version < (3, 4): pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() os.environ.pop("ZOOKEEPER_JAAS_AUTH", None) - def test_connect_sasl_auth(self): + def test_connect_sasl_auth(self) -> None: from kazoo.security import make_acl username = "jaasuser" @@ -56,14 +58,14 @@ def test_connect_sasl_auth(self): client.stop() client.close() - def test_invalid_sasl_auth(self): + def test_invalid_sasl_auth(self) -> None: client = self._get_client(auth_data=[("sasl", "baduser:badpassword")]) with pytest.raises(AuthFailedError): client.start() class TestSASLDigestAuthentication(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: try: import puresasl # NOQA except ImportError: @@ -79,11 +81,11 @@ def setUp(self): if not version or version < (3, 4): pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() os.environ.pop("ZOOKEEPER_JAAS_AUTH", None) - def test_connect_sasl_auth(self): + def test_connect_sasl_auth(self) -> None: from kazoo.security import make_acl username = "jaasuser" @@ -110,7 +112,7 @@ def test_connect_sasl_auth(self): client.stop() client.close() - def test_invalid_sasl_auth(self): + def test_invalid_sasl_auth(self) -> None: client = self._get_client( sasl_options={ "mechanism": "DIGEST-MD5", @@ -123,13 +125,17 @@ def test_invalid_sasl_auth(self): class TestSASLGSSAPIAuthentication(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: + # FIXME Under what circumstances is it ok for these to not be + # available? pip install is a thing. try: - import puresasl # NOQA + import puresasl except ImportError: pytest.skip("PureSASL not available.") try: - import kerberos # NOQA + # FIXME Hound objects to import not found as it thinks it's a + # syntax error. I don't know why it thinks that. + import kerberos # type: ignore except ImportError: pytest.skip("Kerberos support not available.") if not os.environ.get("KRB5_TEST_ENV"): @@ -145,11 +151,11 @@ def setUp(self): if not version or version < (3, 4): pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() os.environ.pop("ZOOKEEPER_JAAS_AUTH", None) - def test_connect_gssapi_auth(self): + def test_connect_gssapi_auth(self) -> None: from kazoo.security import make_acl # Ensure we have a client ticket @@ -177,7 +183,7 @@ def test_connect_gssapi_auth(self): client.stop() client.close() - def test_invalid_gssapi_auth(self): + def test_invalid_gssapi_auth(self) -> None: # Request a post-datated ticket, so that it is currently invalid. subprocess.check_call( [ diff --git a/kazoo/tests/test_security.py b/kazoo/tests/test_security.py index bc45483a..ba21862b 100644 --- a/kazoo/tests/test_security.py +++ b/kazoo/tests/test_security.py @@ -1,19 +1,23 @@ +from __future__ import annotations + import unittest -from kazoo.security import Permissions +from typing import Any + +from kazoo.security import ACL, Permissions class TestACL(unittest.TestCase): - def _makeOne(self, *args, **kwargs): + def _makeOne(self, *args: Any, **kwargs: Any) -> ACL: from kazoo.security import make_acl return make_acl(*args, **kwargs) - def test_read_acl(self): + def test_read_acl(self) -> None: acl = self._makeOne("digest", ":", read=True) - assert acl.perms & Permissions.READ == Permissions.READ + assert (acl.perms & Permissions.READ) == Permissions.READ - def test_all_perms(self): + def test_all_perms(self) -> None: acl = self._makeOne( "digest", ":", @@ -30,25 +34,28 @@ def test_all_perms(self): Permissions.DELETE, Permissions.ADMIN, ]: - assert acl.perms & perm == perm + assert (acl.perms & perm) == perm + + def test_perm_listing(self) -> None: + # FIXME ACL(n, Id) isn't an API, so why do we do this? - def test_perm_listing(self): - from kazoo.security import ACL + from kazoo.security import Id - f = ACL(15, "fred") + f = ACL(15, Id("fred", "bill")) assert "READ" in f.acl_list assert "WRITE" in f.acl_list assert "CREATE" in f.acl_list assert "DELETE" in f.acl_list - f = ACL(16, "fred") + f = ACL(16, Id("fred", "bill")) assert "ADMIN" in f.acl_list - f = ACL(31, "george") + f = ACL(31, Id("fred", "bill")) assert "ALL" in f.acl_list - def test_perm_repr(self): - from kazoo.security import ACL + def test_perm_repr(self) -> None: + # FIXME See above + from kazoo.security import Id - f = ACL(16, "fred") + f = ACL(16, Id("fred", "bill")) assert "ACL(perms=16, acl_list=['ADMIN']" in repr(f) diff --git a/kazoo/tests/test_selectors_select.py b/kazoo/tests/test_selectors_select.py index 99dd44ae..8d367b17 100644 --- a/kazoo/tests/test_selectors_select.py +++ b/kazoo/tests/test_selectors_select.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ The official python select function test case copied from python source to test the selector_select function. @@ -8,7 +10,10 @@ import sys import unittest +from typing import cast + from kazoo.handlers.utils import selector_select +from kazoo.interfaces import HasFileNo select = selector_select @@ -21,10 +26,10 @@ class Nope: pass class Almost: - def fileno(self): + def fileno(self) -> str: return "fileno" - def test_error_conditions(self): + def test_error_conditions(self) -> None: self.assertRaises(TypeError, select, 1, 2, 3) self.assertRaises(TypeError, select, [self.Nope()], [], []) self.assertRaises(TypeError, select, [self.Almost()], [], []) @@ -36,41 +41,44 @@ def test_error_conditions(self): sys.platform.startswith("freebsd"), "skip because of a FreeBSD bug: kern/155606", ) - def test_errno(self): + def test_errno(self) -> None: with open(__file__, "rb") as fp: fd = fp.fileno() fp.close() self.assertRaises(ValueError, select, [fd], [], [], 0) - def test_returned_list_identity(self): + def test_returned_list_identity(self) -> None: # See issue #8329 r, w, x = select([], [], [], 1) self.assertIsNot(r, w) self.assertIsNot(r, x) self.assertIsNot(w, x) - def test_select(self): + def test_select(self) -> None: cmd = "for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done" p = os.popen(cmd, "r") for tout in (0, 1, 2, 4, 8, 16) + (None,) * 10: - rfd, wfd, xfd = select([p], [], [], tout) + rfd, wfd, xfd = select([cast("HasFileNo", p)], [], [], tout) if (rfd, wfd, xfd) == ([], [], []): continue - if (rfd, wfd, xfd) == ([p], [], []): + if (rfd, wfd, xfd) == ([cast("HasFileNo", p)], [], []): line = p.readline() if not line: break continue - self.fail("Unexpected return values from select():", rfd, wfd, xfd) + self.fail( + "Unexpected return values from select(): %s %s %s" + % (rfd, wfd, xfd) + ) p.close() # Issue 16230: Crash on select resized list - def test_select_mutated(self): + def test_select_mutated(self) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - a = [] + a: list[HasFileNo] = [] class F: - def fileno(self): + def fileno(self) -> int: del a[-1] return s.fileno() diff --git a/kazoo/tests/test_threading_handler.py b/kazoo/tests/test_threading_handler.py index eac4ea9f..f4ac5d43 100644 --- a/kazoo/tests/test_threading_handler.py +++ b/kazoo/tests/test_threading_handler.py @@ -1,41 +1,45 @@ +from __future__ import annotations + +import socket import threading import unittest + +from typing import Any, Type from unittest.mock import Mock import pytest +from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler +from kazoo.handlers.utils import create_tcp_socket -class TestThreadingHandler(unittest.TestCase): - def _makeOne(self, *args): - from kazoo.handlers.threading import SequentialThreadingHandler +class TestThreadingHandler(unittest.TestCase): + def _makeOne(self, *args: Any) -> SequentialThreadingHandler: return SequentialThreadingHandler(*args) - def _getAsync(self, *args): - from kazoo.handlers.threading import AsyncResult - + def _getAsync(self) -> Type[AsyncResult]: return AsyncResult - def test_proper_threading(self): + def test_proper_threading(self) -> None: h = self._makeOne() h.start() # In Python 3.3 _Event is gone, before Event is function event_class = getattr(threading, "_Event", threading.Event) assert isinstance(h.event_object(), event_class) - def test_matching_async(self): + def test_matching_async(self) -> None: h = self._makeOne() h.start() async_result = self._getAsync() assert isinstance(h.async_result(), async_result) - def test_exception_raising(self): + def test_exception_raising(self) -> None: h = self._makeOne() with pytest.raises(h.timeout_exception): raise h.timeout_exception("This is a timeout") - def test_double_start_stop(self): + def test_double_start_stop(self) -> None: h = self._makeOne() h.start() assert h._running is True @@ -44,14 +48,11 @@ def test_double_start_stop(self): h.stop() assert h._running is False - def test_huge_file_descriptor(self): + def test_huge_file_descriptor(self) -> None: try: import resource except ImportError: self.skipTest("resource module unavailable on this platform") - import socket - from kazoo.handlers.utils import create_tcp_socket - try: resource.setrlimit(resource.RLIMIT_NOFILE, (4096, 4096)) except (ValueError, resource.error): @@ -71,17 +72,13 @@ def test_huge_file_descriptor(self): class TestThreadingAsync(unittest.TestCase): - def _makeOne(self, *args): - from kazoo.handlers.threading import AsyncResult - + def _makeOne(self, *args: Any) -> AsyncResult: return AsyncResult(*args) - def _makeHandler(self): - from kazoo.handlers.threading import SequentialThreadingHandler - + def _makeHandler(self) -> SequentialThreadingHandler: return SequentialThreadingHandler() - def test_ready(self): + def test_ready(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -91,7 +88,7 @@ def test_ready(self): assert async_result.successful() is True assert async_result.exception is None - def test_callback_queued(self): + def test_callback_queued(self) -> None: mock_handler = Mock() mock_handler.completion_queue = Mock() async_result = self._makeOne(mock_handler) @@ -101,7 +98,7 @@ def test_callback_queued(self): assert mock_handler.completion_queue.put.called - def test_set_exception(self): + def test_set_exception(self) -> None: mock_handler = Mock() mock_handler.completion_queue = Mock() async_result = self._makeOne(mock_handler) @@ -111,7 +108,7 @@ def test_set_exception(self): assert isinstance(async_result.exception, ImportError) assert mock_handler.completion_queue.put.called - def test_get_wait_while_setting(self): + def test_get_wait_while_setting(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -119,7 +116,7 @@ def test_get_wait_while_setting(self): bv = threading.Event() cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: bv.set() val = async_result.get() lst.append(val) @@ -134,7 +131,7 @@ def wait_for_val(): assert lst == ["fred"] th.join() - def test_get_with_nowait(self): + def test_get_with_nowait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) timeout = self._makeHandler().timeout_exception @@ -145,7 +142,7 @@ def test_get_with_nowait(self): with pytest.raises(timeout): async_result.get_nowait() - def test_get_with_exception(self): + def test_get_with_exception(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -153,7 +150,7 @@ def test_get_with_exception(self): bv = threading.Event() cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: bv.set() try: val = async_result.get() @@ -167,20 +164,20 @@ def wait_for_val(): th.start() bv.wait() - async_result.set_exception(ImportError) + async_result.set_exception(ImportError()) cv.wait() assert lst == ["oops"] th.join() - def test_wait(self): + def test_wait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) - lst = [] + lst: list[bool | str] = [] bv = threading.Event() cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: bv.set() try: val = async_result.wait(10) @@ -199,7 +196,7 @@ def wait_for_val(): assert lst == [True] th.join() - def test_wait_race(self): + def test_wait_race(self) -> None: """Test that there is no race condition in `IAsyncResult.wait()`. Guards against the reappearance of: @@ -212,7 +209,7 @@ def test_wait_race(self): cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: # NB: should not sleep async_result.wait(20) cv.set() @@ -227,7 +224,7 @@ def wait_for_val(): assert cv.is_set() is True th.join() - def test_set_before_wait(self): + def test_set_before_wait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -235,7 +232,7 @@ def test_set_before_wait(self): cv = threading.Event() async_result.set("fred") - def wait_for_val(): + def wait_for_val() -> None: val = async_result.get() lst.append(val) cv.set() @@ -246,15 +243,15 @@ def wait_for_val(): assert lst == ["fred"] th.join() - def test_set_exc_before_wait(self): + def test_set_exc_before_wait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) lst = [] cv = threading.Event() - async_result.set_exception(ImportError) + async_result.set_exception(ImportError()) - def wait_for_val(): + def wait_for_val() -> None: try: val = async_result.get() except ImportError: @@ -269,17 +266,17 @@ def wait_for_val(): assert lst == ["ooops"] th.join() - def test_linkage(self): + def test_linkage(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) cv = threading.Event() lst = [] - def add_on(): + def add_on() -> None: lst.append(True) - def wait_for_val(): + def wait_for_val() -> None: async_result.get() cv.set() @@ -294,13 +291,13 @@ def wait_for_val(): assert async_result.value == b"fred" th.join() - def test_linkage_not_ready(self): + def test_linkage_not_ready(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) lst = [] - def add_on(): + def add_on() -> None: lst.append(True) async_result.set("fred") @@ -308,13 +305,13 @@ def add_on(): async_result.rawlink(add_on) assert mock_handler.completion_queue.put.called - def test_link_and_unlink(self): + def test_link_and_unlink(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) lst = [] - def add_on(): + def add_on() -> None: lst.append(True) async_result.rawlink(add_on) @@ -323,14 +320,14 @@ def add_on(): async_result.set("fred") assert not mock_handler.completion_queue.put.called - def test_captured_exception(self): + def test_captured_exception(self) -> None: from kazoo.handlers.utils import capture_exceptions mock_handler = Mock() async_result = self._makeOne(mock_handler) @capture_exceptions(async_result) - def exceptional_function(): + def exceptional_function() -> float: return 1 / 0 exceptional_function() @@ -338,7 +335,7 @@ def exceptional_function(): with pytest.raises(ZeroDivisionError): async_result.get() - def test_no_capture_exceptions(self): + def test_no_capture_exceptions(self) -> None: from kazoo.handlers.utils import capture_exceptions mock_handler = Mock() @@ -346,20 +343,20 @@ def test_no_capture_exceptions(self): lst = [] - def add_on(): + def add_on() -> None: lst.append(True) async_result.rawlink(add_on) @capture_exceptions(async_result) - def regular_function(): + def regular_function() -> bool: return True regular_function() assert not mock_handler.completion_queue.put.called - def test_wraps(self): + def test_wraps(self) -> None: from kazoo.handlers.utils import wrap mock_handler = Mock() @@ -367,20 +364,20 @@ def test_wraps(self): lst = [] - def add_on(result): + def add_on(result: AsyncResult) -> None: lst.append(result.get()) async_result.rawlink(add_on) @wrap(async_result) - def regular_function(): + def regular_function() -> str: return "hello" assert regular_function() == "hello" assert mock_handler.completion_queue.put.called assert async_result.get() == "hello" - def test_multiple_callbacks(self): + def test_multiple_callbacks(self) -> None: mockback1 = Mock(name="mockback1") mockback2 = Mock(name="mockback2") handler = self._makeHandler() diff --git a/kazoo/tests/test_utils.py b/kazoo/tests/test_utils.py index 96d484a8..b686cabd 100644 --- a/kazoo/tests/test_utils.py +++ b/kazoo/tests/test_utils.py @@ -1,10 +1,18 @@ +from __future__ import annotations + +import ssl +import socket +import time +from kazoo.handlers import utils +from kazoo.handlers.utils import create_tcp_connection + import unittest from unittest.mock import patch import pytest try: - from kazoo.handlers.eventlet import green_socket as socket + from eventlet.green import socket as green_socket EVENTLET_HANDLER_AVAILABLE = True except ImportError: @@ -12,10 +20,7 @@ class TestCreateTCPConnection(unittest.TestCase): - def test_timeout_arg(self): - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, socket, time - + def test_timeout_arg(self) -> None: with patch.object(socket, "create_connection") as create_connection: with patch.object(utils, "_set_default_tcpsock_options"): # Ensure a gap between calls to time.time() does not result in @@ -30,10 +35,7 @@ def test_timeout_arg(self): timeout = call_args[0][1] assert timeout >= 0, "socket timeout must be nonnegative" - def test_ssl_server_hostname(self): - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, socket, ssl - + def test_ssl_server_hostname(self) -> None: with patch.object(utils, "_set_default_tcpsock_options"): with patch.object(ssl.SSLContext, "wrap_socket") as wrap_socket: create_tcp_connection( @@ -48,10 +50,7 @@ def test_ssl_server_hostname(self): server_hostname = call_args[1]["server_hostname"] assert server_hostname == "fakehostname" - def test_ssl_server_check_hostname(self): - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, socket, ssl - + def test_ssl_server_check_hostname(self) -> None: with patch.object(utils, "_set_default_tcpsock_options"): with patch.object( ssl.SSLContext, "wrap_socket", autospec=True @@ -69,9 +68,7 @@ def test_ssl_server_check_hostname(self): ssl_context = call_args[0][0] assert ssl_context.check_hostname - def test_ssl_server_check_hostname_config_validation(self): - from kazoo.handlers.utils import create_tcp_connection, socket - + def test_ssl_server_check_hostname_config_validation(self) -> None: with pytest.raises(ValueError): create_tcp_connection( socket, @@ -83,51 +80,45 @@ def test_ssl_server_check_hostname_config_validation(self): check_hostname=True, ) - def test_timeout_arg_eventlet(self): + def test_timeout_arg_eventlet(self) -> None: if not EVENTLET_HANDLER_AVAILABLE: pytest.skip("eventlet handler not available.") - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, time - - with patch.object(socket, "create_connection") as create_connection: + with patch.object( + green_socket, "create_connection" + ) as create_connection: with patch.object(utils, "_set_default_tcpsock_options"): # Ensure a gap between calls to time.time() does not result in # create_connection being called with a negative timeout # argument. with patch.object(time, "time", side_effect=range(10)): create_tcp_connection( - socket, ("127.0.0.1", 2181), timeout=1.5 + green_socket, ("127.0.0.1", 2181), timeout=1.5 ) for call_args in create_connection.call_args_list: timeout = call_args[0][1] assert timeout >= 0, "socket timeout must be nonnegative" - def test_slow_connect(self): + def test_slow_connect(self) -> None: # Currently, create_tcp_connection will raise a socket timeout if it # takes longer than the specified "timeout" to create a connection. # In the future, "timeout" might affect only the created socket and not # the time it takes to create it. - from kazoo.handlers.utils import create_tcp_connection, socket, time - # Simulate a second passing between calls to check the current time. with patch.object(time, "time", side_effect=range(10)): with pytest.raises(socket.error): create_tcp_connection(socket, ("127.0.0.1", 2181), timeout=0.5) - def test_negative_timeout(self): - from kazoo.handlers.utils import create_tcp_connection, socket - + def test_negative_timeout(self) -> None: with pytest.raises(socket.error): create_tcp_connection(socket, ("127.0.0.1", 2181), timeout=-1) - def test_zero_timeout(self): + def test_zero_timeout(self) -> None: # Rather than pass '0' through as a timeout to # socket.create_connection, create_tcp_connection should raise # socket.error. This is because the socket library treats '0' as an # indicator to create a non-blocking socket. - from kazoo.handlers.utils import create_tcp_connection, socket, time # Simulate no time passing between calls to check the current time. with patch.object(time, "time", return_value=time.time()): diff --git a/kazoo/tests/test_watchers.py b/kazoo/tests/test_watchers.py index dc36ef61..e0852a18 100644 --- a/kazoo/tests/test_watchers.py +++ b/kazoo/tests/test_watchers.py @@ -1,32 +1,39 @@ +from __future__ import annotations + import time import threading import uuid +from typing import Any, List, Literal + import pytest from kazoo.exceptions import KazooException -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent, ZnodeStat +from kazoo.recipe.watchers import PatientChildrenWatch + from kazoo.testing import KazooTestCase class KazooDataWatcherTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooDataWatcherTests, self).setUp() self.path = "/" + uuid.uuid4().hex self.client.ensure_path(self.path) - def test_data_watcher(self): + def test_data_watcher(self) -> None: update = threading.Event() - data = [True] + data: list[bool | bytes | None] = [True] # Make it a non-existent path self.path += "f" @self.client.DataWatch(self.path) - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() data.append(d) update.set() + return None update.wait(10) assert data == [None] @@ -37,9 +44,9 @@ def changed(d, stat): assert data[0] == b"fred" update.clear() - def test_data_watcher_once(self): + def test_data_watcher_once(self) -> None: update = threading.Event() - data = [True] + data: list[bool | bytes | None] = [True] # Make it a non-existent path self.path += "f" @@ -47,10 +54,11 @@ def test_data_watcher_once(self): dwatcher = self.client.DataWatch(self.path) @dwatcher - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() data.append(d) update.set() + return None update.wait(10) assert data == [None] @@ -59,23 +67,27 @@ def changed(d, stat): with pytest.raises(KazooException): @dwatcher - def func(d, stat): + def func(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() + return None - def test_data_watcher_with_event(self): + def test_data_watcher_with_event(self) -> None: # Test that the data watcher gets passed the event, if it # accepts three arguments update = threading.Event() - data = [True] + data: list[Literal[True] | WatchedEvent | None] = [True] # Make it a non-existent path self.path += "f" @self.client.DataWatch(self.path) - def changed(d, stat, event): + def changed( + d: bytes | None, stat: ZnodeStat | None, event: WatchedEvent | None + ) -> bool | None: data.pop() data.append(event) update.set() + return None update.wait(10) assert data == [None] @@ -83,17 +95,18 @@ def changed(d, stat, event): self.client.create(self.path, b"fred") update.wait(10) + assert data[0] is not None and data[0] is not True assert data[0].type == EventType.CREATED update.clear() - def test_func_style_data_watch(self): + def test_func_style_data_watch(self) -> None: update = threading.Event() - data = [True] + data: list[bytes | None | Literal[True]] = [True] # Make it a non-existent path path = self.path + "f" - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> None: data.pop() data.append(d) update.set() @@ -109,12 +122,12 @@ def changed(d, stat): assert data[0] == b"fred" update.clear() - def test_datawatch_across_session_expire(self): + def test_datawatch_across_session_expire(self) -> None: update = threading.Event() - data = [True] + data: list[bytes | None | Literal[True]] = [True] @self.client.DataWatch(self.path) - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> None: data.pop() data.append(d) update.set() @@ -128,21 +141,22 @@ def changed(d, stat): update.wait(25) assert data[0] == b"fred" - def test_func_stops(self): + def test_func_stops(self) -> None: update = threading.Event() - data = [True] + data: list[bytes | None | Literal[True]] = [True] self.path += "f" - fail_through = [] + fail_through: list[bool] = [] @self.client.DataWatch(self.path) - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() data.append(d) update.set() if fail_through: return False + return None update.wait(10) assert data == [None] @@ -161,21 +175,21 @@ def changed(d, stat): d, stat = self.client.get(self.path) assert d == b"asdfasdf" - def test_no_such_node(self): + def test_no_such_node(self) -> None: args = [] @self.client.DataWatch("/some/path") - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> None: args.extend([d, stat]) assert args == [None, None] - def test_no_such_node_for_children_watch(self): + def test_no_such_node_for_children_watch(self) -> None: args = [] path = self.path + "/test_no_such_node_for_children_watch" update = threading.Event() - def changed(children): + def changed(children: list[str] | None) -> None: args.append(children) update.set() @@ -218,9 +232,9 @@ def changed(children): assert update.is_set() is False assert children_watch._stopped is True - def test_watcher_evaluating_to_false(self): - class WeirdWatcher(list): - def __call__(self, *args): + def test_watcher_evaluating_to_false(self) -> None: + class WeirdWatcher(List[Any]): + def __call__(self, *args: Any) -> None: self.called = True watcher = WeirdWatcher() @@ -228,14 +242,14 @@ def __call__(self, *args): self.client.set(self.path, b"mwahaha") assert watcher.called is True - def test_watcher_repeat_delete(self): + def test_watcher_repeat_delete(self) -> None: a = [] ev = threading.Event() self.client.delete(self.path) @self.client.DataWatch(self.path) - def changed(val, stat): + def changed(val: bytes | None, stat: ZnodeStat | None) -> None: a.append(val) ev.set() @@ -258,14 +272,14 @@ def changed(val, stat): ev.clear() assert a == [None, b"blah", None, b"blah"] - def test_watcher_with_closing(self): + def test_watcher_with_closing(self) -> None: a = [] ev = threading.Event() self.client.delete(self.path) @self.client.DataWatch(self.path) - def changed(val, stat): + def changed(val: bytes | None, stat: ZnodeStat | None) -> None: a.append(val) ev.set() @@ -280,17 +294,18 @@ def changed(val, stat): class KazooChildrenWatcherTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooChildrenWatcherTests, self).setUp() self.path = "/" + uuid.uuid4().hex self.client.ensure_path(self.path) - def test_child_watcher(self): + def test_child_watcher(self) -> None: update = threading.Event() all_children = ["fred"] @self.client.ChildrenWatch(self.path) - def changed(children): + def changed(children: list[str] | None) -> None: + assert children is not None while all_children: all_children.pop() all_children.extend(children) @@ -309,16 +324,17 @@ def changed(children): update.wait(10) assert sorted(all_children) == ["george", "smith"] - def test_child_watcher_once(self): + def test_child_watcher_once(self) -> None: update = threading.Event() all_children = ["fred"] cwatch = self.client.ChildrenWatch(self.path) @cwatch - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -329,18 +345,21 @@ def changed(children): with pytest.raises(KazooException): @cwatch - def changed_again(children): + def changed_again(children: list[str] | None) -> None: update.set() - def test_child_watcher_with_event(self): + def test_child_watcher_with_event(self) -> None: update = threading.Event() - events = [True] + events: list[WatchedEvent | None | Literal[True]] = [True] @self.client.ChildrenWatch(self.path, send_event=True) - def changed(children, event): + def changed( + children: list[str] | None, event: WatchedEvent | None + ) -> bool | None: events.pop() events.append(event) update.set() + return None update.wait(10) assert events == [None] @@ -348,16 +367,19 @@ def changed(children, event): self.client.create(self.path + "/" + "smith") update.wait(10) + assert events[0] is not None + assert events[0] is not True assert events[0].type == EventType.CHILD update.clear() - def test_func_style_child_watcher(self): + def test_func_style_child_watcher(self) -> None: update = threading.Event() all_children = ["fred"] - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -376,20 +398,22 @@ def changed(children): update.wait(10) assert sorted(all_children) == ["george", "smith"] - def test_func_stops(self): + def test_func_stops(self) -> None: update = threading.Event() all_children = ["fred"] - fail_through = [] + fail_through: list[bool] = [] @self.client.ChildrenWatch(self.path) - def changed(children): + def changed(children: list[str] | None) -> bool | None: + assert children is not None while all_children: all_children.pop() all_children.extend(children) update.set() if fail_through: return False + return None # ? True? update.wait(10) assert all_children == [] @@ -405,19 +429,21 @@ def changed(children): update.wait(0.5) assert all_children == ["smith"] - def test_child_watcher_remove_session_watcher(self): + def test_child_watcher_remove_session_watcher(self) -> None: update = threading.Event() all_children = ["fred"] - fail_through = [] + fail_through: list[bool] = [] - def changed(children): + def changed(children: list[str] | None) -> bool | None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() if fail_through: return False + return None # ? children_watch = self.client.ChildrenWatch(self.path, changed) session_watcher = children_watch._session_watcher @@ -439,14 +465,15 @@ def changed(children): assert session_watcher not in self.client.state_listeners assert all_children == ["smith"] - def test_child_watch_session_loss(self): + def test_child_watch_session_loss(self) -> None: update = threading.Event() all_children = ["fred"] @self.client.ChildrenWatch(self.path) - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -464,14 +491,15 @@ def changed(children): update.wait(20) assert sorted(all_children) == ["george", "smith"] - def test_child_stop_on_session_loss(self): + def test_child_stop_on_session_loss(self) -> None: update = threading.Event() all_children = ["fred"] @self.client.ChildrenWatch(self.path, allow_session_lost=False) - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -495,16 +523,14 @@ def changed(children): class KazooPatientChildrenWatcherTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooPatientChildrenWatcherTests, self).setUp() self.path = "/" + uuid.uuid4().hex - def _makeOne(self, *args, **kwargs): - from kazoo.recipe.watchers import PatientChildrenWatch - + def _makeOne(self, *args: Any, **kwargs: Any) -> PatientChildrenWatch: return PatientChildrenWatch(*args, **kwargs) - def test_watch(self): + def test_watch(self) -> None: self.client.ensure_path(self.path) watcher = self._makeOne(self.client, self.path, 0.1) result = watcher.start() @@ -516,7 +542,7 @@ def test_watch(self): asy.get(timeout=1) assert asy.ready() is True - def test_exception(self): + def test_exception(self) -> None: from kazoo.exceptions import NoNodeError watcher = self._makeOne(self.client, self.path, 0.1) @@ -525,7 +551,7 @@ def test_exception(self): with pytest.raises(NoNodeError): result.get() - def test_watch_iterations(self): + def test_watch_iterations(self) -> None: self.client.ensure_path(self.path) watcher = self._makeOne(self.client, self.path, 0.5) result = watcher.start() diff --git a/kazoo/tests/util.py b/kazoo/tests/util.py index f7298718..c6211238 100644 --- a/kazoo/tests/util.py +++ b/kazoo/tests/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + ############################################################################## # # Copyright Zope Foundation and Contributors. @@ -16,46 +18,50 @@ import os import time +from typing import Any, Callable, Type + CI = os.environ.get("CI", False) -CI_ZK_VERSION = CI and os.environ.get("ZOOKEEPER_VERSION", None) -if CI_ZK_VERSION: - if "-" in CI_ZK_VERSION: - # Ignore pre-release markers like -alpha - CI_ZK_VERSION = CI_ZK_VERSION.split("-")[0] - CI_ZK_VERSION = tuple([int(n) for n in CI_ZK_VERSION.split(".")]) +CI_ZK_VERSION: tuple[int, ...] = tuple() +if CI: + has_version = os.environ.get("ZOOKEEPER_VERSION", "") + if has_version: + if "-" in has_version: + # Ignore pre-release markers like -alpha + has_version = has_version.split("-")[0] + CI_ZK_VERSION = tuple([int(n) for n in has_version.split(".")]) class Handler(logging.Handler): - def __init__(self, *names, **kw): + def __init__(self, *names: Any, **kw: Any): logging.Handler.__init__(self) self.names = names - self.records = [] + self.records: list[Any] = [] self.setLoggerLevel(**kw) - def setLoggerLevel(self, level=1): + def setLoggerLevel(self, level: int = 1) -> None: self.level = level - self.oldlevels = {} + self.oldlevels: dict[str, int] = {} - def emit(self, record): + def emit(self, record: Any) -> None: self.records.append(record) - def clear(self): + def clear(self) -> None: del self.records[:] - def install(self): + def install(self) -> None: for name in self.names: logger = logging.getLogger(name) self.oldlevels[name] = logger.level logger.setLevel(self.level) logger.addHandler(self) - def uninstall(self): + def uninstall(self) -> None: for name in self.names: logger = logging.getLogger(name) logger.setLevel(self.oldlevels[name]) logger.removeHandler(self) - def __str__(self): + def __str__(self) -> str: return "\n".join( [ ( @@ -78,7 +84,7 @@ def __str__(self): class InstalledHandler(Handler): - def __init__(self, *names, **kw): + def __init__(self, *names: Any, **kw: Any): Handler.__init__(self, *names, **kw) self.install() @@ -92,11 +98,11 @@ class TimeOutWaitingFor(Exception): def __init__( self, - timeout=None, - wait=None, - exception=None, - getnow=(lambda: time.monotonic), - getsleep=(lambda: time.sleep), + timeout: int | None = None, + wait: float | None = None, + exception: Type[Exception] = TimeOutWaitingFor, + getnow: Callable[[], Callable[[], float]] = (lambda: time.monotonic), + getsleep: Callable[[], Callable[[float], None]] = (lambda: time.sleep), ): if timeout is not None: self.timeout = timeout @@ -104,15 +110,21 @@ def __init__( if wait is not None: self.wait = wait - if exception is not None: - self.TimeOutWaitingFor = exception + self.exception = exception self.getnow = getnow self.getsleep = getsleep - def __call__(self, func=None, timeout=None, wait=None, message=None): - if func is None: - return lambda func: self(func, timeout, wait, message) + def __call__( + self, + func: Callable[[], Any], # | None = None, + timeout: float | None = None, + wait: float | None = None, + message: str | None = None, + ) -> None: + # if func is None: + # # Seriously WTF? + # return lambda func: self(func, timeout, wait, message) if func(): return @@ -131,9 +143,8 @@ def __call__(self, func=None, timeout=None, wait=None, message=None): if func(): return if now() > deadline: - raise self.TimeOutWaitingFor( - message or func.__doc__ or func.__name__ - ) + # raise self.TimeOutWaitingFor( + raise self.exception(message or func.__doc__ or func.__name__) wait = Wait() diff --git a/pyproject.toml b/pyproject.toml index db3890c5..869ed27c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,18 +6,9 @@ requires = [ [tool.black] line-length = 79 -target-version = ['py37', 'py38', 'py39', 'py310'] +# We need a later version of black for 312-314 +target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' -# 'extend-exclude' excludes files or directories in addition to the defaults -# A regex preceded with ^/ will apply only to files and directories -# in the root of the project. -# ( -# ^/foo.py # exclude a file named foo.py in the root of the project. -# | .*_pb2.py # exclude autogenerated Protocol Buffer files anywhere in the -# ) -# # project. -extend-exclude = ''' -''' [tool.pytest.ini_options] addopts = "-ra -v --color=yes" @@ -42,8 +33,12 @@ ignore_missing_imports = false # Disallow dynamic typing disallow_any_unimported = true disallow_any_expr = false -disallow_any_decorated = true -disallow_any_explicit = true +# FIXME disallow_any_decorated is disabled because it produces an error for almost +# every decorated function (at least in python3.8) +disallow_any_decorated = false +# FIXME disallow_any_explicit is disabled because it produces a lot of errors +# which are currently hard to fix, and we want to avoid code changes as much as possible. +disallow_any_explicit = false disallow_any_generics = true disallow_subclassing_any = true @@ -68,6 +63,8 @@ warn_unreachable = true # Miscellaneous strictness flags allow_untyped_globals = false allow_redefinition = false +# FIXME Disabled for now as we need to fix the X509 import. +#enable_error_code = ["deprecated"] local_partial_types = true implicit_reexport = false strict_concatenate = true @@ -81,71 +78,25 @@ hide_error_codes = false pretty = true color_output = true error_summary = true -show_absolute_path = true +show_absolute_path = false # Miscellaneous warn_unused_configs = true verbosity = 0 -# FIXME: As type annotations are introduced, please remove the appropriate -# ignore_errors flag below. New modules should NOT be added here! +# This is a temporary workaround for the fact that mypy can +# produce different errors in 3.8 and 3.14, and I want to avoid code changes +# as much as possible. +disable_error_code = [ 'unused-ignore' ] [[tool.mypy.overrides]] -module = [ - 'kazoo.client', - 'kazoo.exceptions', - 'kazoo.handlers.eventlet', - 'kazoo.handlers.gevent', - 'kazoo.handlers.threading', - 'kazoo.handlers.utils', - 'kazoo.hosts', - 'kazoo.interfaces', - 'kazoo.loggingsupport', - 'kazoo.protocol.connection', - 'kazoo.protocol.paths', - 'kazoo.protocol.serialization', - 'kazoo.protocol.states', - 'kazoo.recipe.barrier', - 'kazoo.recipe.cache', - 'kazoo.recipe.counter', - 'kazoo.recipe.election', - 'kazoo.recipe.lease', - 'kazoo.recipe.lock', - 'kazoo.recipe.partitioner', - 'kazoo.recipe.party', - 'kazoo.recipe.queue', - 'kazoo.recipe.watchers', - 'kazoo.retry', - 'kazoo.security', - 'kazoo.testing.common', - 'kazoo.testing.harness', - 'kazoo.tests.conftest', - 'kazoo.tests.test_barrier', - 'kazoo.tests.test_build', - 'kazoo.tests.test_cache', - 'kazoo.tests.test_client', - 'kazoo.tests.test_connection', - 'kazoo.tests.test_counter', - 'kazoo.tests.test_election', - 'kazoo.tests.test_eventlet_handler', - 'kazoo.tests.test_exceptions', - 'kazoo.tests.test_gevent_handler', - 'kazoo.tests.test_hosts', - 'kazoo.tests.test_interrupt', - 'kazoo.tests.test_lease', - 'kazoo.tests.test_lock', - 'kazoo.tests.test_partitioner', - 'kazoo.tests.test_party', - 'kazoo.tests.test_paths', - 'kazoo.tests.test_queue', - 'kazoo.tests.test_retry', - 'kazoo.tests.test_sasl', - 'kazoo.tests.test_security', - 'kazoo.tests.test_selectors_select', - 'kazoo.tests.test_threading_handler', - 'kazoo.tests.test_utils', - 'kazoo.tests.test_watchers', - 'kazoo.tests.util', - 'kazoo.version' -] -ignore_errors = true + module = ["eventlet.*"] + follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["gevent.thread"] + follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["puresasl.*"] + follow_untyped_imports = true diff --git a/setup.cfg b/setup.cfg index e3a7b880..5bb8c0fb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,19 +43,26 @@ zip_safe = false include_package_data = true packages = find: +[options.package_data] +kazoo = py.typed + [aliases] release = sdist bdist_wheel [egg_info] tag_build = dev -[bdist_wheel] -universal = true - [options.extras_require] dev = + black flake8 +# I have to pin this HERE because I can't persuade vscode to +# look in constraints.txt +other = + pyOpenSSL<26.2.0 + typing-extensions + test = objgraph pytest @@ -64,7 +71,7 @@ test = gevent>=1.2 ; implementation_name!='pypy' eventlet>=0.17.1 ; implementation_name!='pypy' pyjks - pyOpenSSL + %(other)s eventlet = eventlet>=0.17.1 @@ -80,8 +87,11 @@ docs = sphinx-autodoc-typehints>=1 typing = - mypy>=0.991 - pyOpenSSL + mypy==1.14.1 + types-gevent + types-mock + types-objgraph + types-pyjks alldeps = %(dev)s @@ -89,4 +99,5 @@ alldeps = %(gevent)s %(sasl)s %(docs)s + %(other)s %(typing)s diff --git a/tox.ini b/tox.ini index 567acd29..054fd9d0 100644 --- a/tox.ini +++ b/tox.ini @@ -62,6 +62,8 @@ basepython = python3 extras = alldeps deps = mypy - mypy: types-mock + types-mock + pyopenssl + pytest usedevelop = True commands = mypy --config-file {toxinidir}/pyproject.toml {toxinidir}/kazoo