From ebb5665ec975ef93a20bc64f0ce122aa0b1341ea Mon Sep 17 00:00:00 2001 From: Thomas Tanner Date: Sat, 16 May 2026 20:50:43 +0100 Subject: [PATCH] feat: Add type hinting This adds nearly full type hinting, and enables all but two of the mypy checks that were enabled in the mypy section in pyproject.toml. I hope to address the remaining checks, but it was proving tricky to address some of the 'typing.Any' uses. Also if you enable deprecation checks with --enable-error-code=deprecated You'll get warnings if you use deprecated options (at least as far as I was able to spot them!) This addresses #647 but there's not entirely strict checking as yet. Although I've tried very hard not to change any code for this, some `return None` statements have been added, which has caused the code coverage to drop as apparently those return paths were not covered by the testing. Also I've been less polite with the testing, and have changed the code in a number of places, including dropping test classes that did nothing, and adding asserts here and there to keep mypy happy as it sometimes doesn't spot a value could have been changed. I've added FIXMEs to address some of the more egregious type warning suppression that I did to avoid doing code changes, and intend to address those in another PR. A note: I've also had to suppress some errors in flake8 and use more general 'type: ignore' suppressions than I'd like because hound CI does something very very odd and tries to parse the text in the [] of the type: ignore[] comments as python, and generates some very peculiar messages. --- .coveragerc | 7 + .flake8 | 28 +- .gitignore | 5 +- constraints.txt | 6 + docs/testing.rst | 4 +- kazoo/client.py | 654 +++++++++++++++++++------- kazoo/exceptions.py | 21 +- kazoo/handlers/eventlet.py | 129 +++-- kazoo/handlers/gevent.py | 79 +++- kazoo/handlers/threading.py | 79 +++- kazoo/handlers/utils.py | 167 +++++-- kazoo/hosts.py | 10 +- kazoo/interfaces.py | 242 +++++++++- kazoo/protocol/connection.py | 289 +++++++++--- kazoo/protocol/paths.py | 12 +- kazoo/protocol/serialization.py | 272 +++++++---- kazoo/protocol/states.py | 66 ++- kazoo/py.typed | 0 kazoo/recipe/barrier.py | 41 +- kazoo/recipe/cache.py | 161 +++++-- kazoo/recipe/counter.py | 41 +- kazoo/recipe/election.py | 29 +- kazoo/recipe/lease.py | 66 +-- kazoo/recipe/lock.py | 155 ++++-- kazoo/recipe/partitioner.py | 112 +++-- kazoo/recipe/party.py | 57 ++- kazoo/recipe/queue.py | 59 ++- kazoo/recipe/watchers.py | 151 ++++-- kazoo/retry.py | 43 +- kazoo/security.py | 56 ++- kazoo/testing/common.py | 95 ++-- kazoo/testing/harness.py | 109 +++-- kazoo/tests/conftest.py | 4 +- kazoo/tests/test_barrier.py | 34 +- kazoo/tests/test_build.py | 10 +- kazoo/tests/test_cache.py | 170 ++++--- kazoo/tests/test_client.py | 465 +++++++++--------- kazoo/tests/test_connection.py | 127 ++--- kazoo/tests/test_counter.py | 25 +- kazoo/tests/test_election.py | 43 +- kazoo/tests/test_eventlet_handler.py | 152 ++---- kazoo/tests/test_exceptions.py | 14 +- kazoo/tests/test_gevent_handler.py | 107 ++--- kazoo/tests/test_hosts.py | 8 +- kazoo/tests/test_interrupt.py | 4 +- kazoo/tests/test_lease.py | 36 +- kazoo/tests/test_lock.py | 169 ++++--- kazoo/tests/test_partitioner.py | 73 +-- kazoo/tests/test_party.py | 20 +- kazoo/tests/test_paths.py | 34 +- kazoo/tests/test_queue.py | 64 +-- kazoo/tests/test_retry.py | 32 +- kazoo/tests/test_sasl.py | 34 +- kazoo/tests/test_security.py | 35 +- kazoo/tests/test_selectors_select.py | 30 +- kazoo/tests/test_threading_handler.py | 107 ++--- kazoo/tests/test_utils.py | 51 +- kazoo/tests/test_watchers.py | 148 +++--- kazoo/tests/util.py | 69 +-- pyproject.toml | 99 +--- setup.cfg | 23 +- tox.ini | 4 +- 62 files changed, 3468 insertions(+), 1968 deletions(-) create mode 100644 kazoo/py.typed diff --git a/.coveragerc b/.coveragerc index d84a6fc8b..c79b53748 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,10 @@ include = omit = kazoo/tests/* kazoo/testing/* + +# Note - this is a copy of the default exclusions from coverage 7.10.1 +[report] +exclude_lines = + #\s*(pragma|PRAGMA)[:\s]?\s*(no|NO)\s*(cover|COVER) + ^\s*(((async )?def .*?)?\)(\s*->.*?)?:\s*)?\.\.\.\s*(#|$) + if (typing\.)?TYPE_CHECKING: diff --git a/.flake8 b/.flake8 index ba8a3d67f..75277d346 100644 --- a/.flake8 +++ b/.flake8 @@ -1,14 +1,28 @@ [flake8] builtins = _ exclude = + docs/conf.py, .git, __pycache__, - .venv/,venv/, - .tox/, - build/,dist/,*egg, - docs/conf.py, - zookeeper/ -# See black's documentation for E203 + .venv/, + venv*/, + .tox*/, + build/, + dist/, + *egg, + zookeeper/, + max-line-length = 79 -extend-ignore = BLK100,E203 + +# See black's documentation for E203 +# +# I am not sure what version of flake8 hound is using but it gives a lot of +# undefined names for comments (F821) and redefinition of unused variables +# (F811). +# +# I've also had to supress F401 because you generally also need to specify +# type: ignore and something gets very upset when you have both in the same +# comment. +# +extend-ignore = BLK100,E203,F401,F811,F821 diff --git a/.gitignore b/.gitignore index 1c2b4b245..cc13a3453 100644 --- a/.gitignore +++ b/.gitignore @@ -29,14 +29,15 @@ zookeeper/ .idea .project .pydevproject -.tox +.tox*/ venv*/ /.settings /.metadata +__pycache__/ !.gitignore !.git-blame-ignore-revs -.vscode/settings.json +.vscode/ .*_cache/ coverage.xml diff --git a/constraints.txt b/constraints.txt index 51c48a101..58be5adbb 100644 --- a/constraints.txt +++ b/constraints.txt @@ -1,8 +1,14 @@ # Consistent testing environment. +# requirements.txt +eventlet>=0.17.1 ; implementation_name!='pypy' +gevent>=1.2 ; implementation_name!='pypy' + +# requirements-dev.txt black==22.10.0 coverage==6.3.2; python_version=="3.8" coverage==7.10.7; python_version > "3.8" flake8==5.0.2 +mypy==1.14.1 objgraph==3.5.0 pytest==6.2.5; python_version=="3.8" pytest==8.4.2; python_version > "3.8" diff --git a/docs/testing.rst b/docs/testing.rst index c98e35a89..b6d057beb 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -37,10 +37,10 @@ Example: from kazoo.testing import KazooTestHarness class MyTest(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: self.setup_zookeeper() - def tearDown(self): + def tearDown(self)-> None: self.teardown_zookeeper() def testmycode(self): diff --git a/kazoo/client.py b/kazoo/client.py index 3029d1c5f..893a6ca83 100644 --- a/kazoo/client.py +++ b/kazoo/client.py @@ -1,4 +1,7 @@ """Kazoo Zookeeper Client""" + +from __future__ import annotations + from collections import defaultdict, deque from functools import partial import inspect @@ -6,6 +9,22 @@ from os.path import split import re import warnings +from types import TracebackType +from typing import ( + cast, + overload, + Any, + Callable, + Deque, + Iterable, + Literal, + Optional, + Sequence, + Set, + TypedDict, + TYPE_CHECKING, +) +from typing_extensions import ParamSpec, Unpack, deprecated from kazoo.exceptions import ( AuthFailedError, @@ -63,6 +82,10 @@ from kazoo.recipe.queue import Queue, LockingQueue from kazoo.recipe.watchers import ChildrenWatch, DataWatch +if TYPE_CHECKING: + from kazoo.interfaces import Event, IAsyncResult, IHandler + from kazoo.protocol.states import ZnodeStat + CLOSED_STATES = ( KeeperState.EXPIRED_SESSION, @@ -89,6 +112,35 @@ ) +class LegacyRetryParams(TypedDict, total=False): + max_retries: int + retry_delay: float + retry_backoff: int + retry_max_delay: float + + +# Signature for functions called by add_listener +ListenerFunc = Callable[[KazooState], Optional[bool]] + +# Signatures for get, get_children and exists watches +WatchFunc = Callable[[WatchedEvent], Optional[bool]] + +GenericArgs = ParamSpec("GenericArgs") + + +# Kazoo retry parameters +class KazooRetryParams(TypedDict, total=False): + max_tries: int + delay: float + backoff: int + max_jitter: float + max_delay: float + ignore_expire: bool + sleep_func: Callable[[float], None] + deadline: float + interrupt: Callable[[], bool] + + class KazooClient(object): """An Apache Zookeeper Python client supporting alternate callback handlers and high-level functionality. @@ -100,29 +152,86 @@ class KazooClient(object): """ + @overload + def __init__( + self, + hosts: str | list[str] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: tuple[int | None, bytes] | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: Iterable[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, + read_only: bool | None = None, + randomize_hosts: bool = True, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + ) -> None: + ... + + # FIXME This should be deprecated then killed + @overload + @deprecated( + "Passing retry configuration parameters directly to the client" + " is deprecated, please pass a configured retry object (using param" + " connection_retry or command_retry)" + ) + def __init__( + self, + hosts: str | list[str] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: tuple[int | None, bytes] | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: Iterable[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, + read_only: bool | None = None, + randomize_hosts: bool = True, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + **kwargs: Unpack[LegacyRetryParams], + ) -> None: + ... + def __init__( self, - hosts="127.0.0.1:2181", - timeout=10.0, - client_id=None, - handler=None, - default_acl=None, - auth_data=None, - sasl_options=None, - read_only=None, - randomize_hosts=True, - connection_retry=None, - command_retry=None, - logger=None, - keyfile=None, - keyfile_password=None, - certfile=None, - ca=None, - use_ssl=False, - verify_certs=True, - check_hostname=False, - **kwargs, - ): + hosts: str | list[str] = "127.0.0.1:2181", + timeout: float = 10.0, + client_id: tuple[int | None, bytes] | None = None, + handler: IHandler | None = None, + default_acl: Sequence[ACL] | None = None, + auth_data: Iterable[tuple[str, str]] | None = None, + sasl_options: dict[str, str] | None = None, + read_only: bool | None = None, + randomize_hosts: bool = True, + connection_retry: KazooRetry | KazooRetryParams | None = None, + command_retry: KazooRetry | KazooRetryParams | None = None, + logger: logging.Logger | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + certfile: str | None = None, + ca: str | None = None, + use_ssl: bool = False, + verify_certs: bool = True, + check_hostname: bool = False, + **kwargs: Unpack[LegacyRetryParams], + ) -> None: """Create a :class:`KazooClient` instance. All time arguments are in seconds. @@ -231,11 +340,18 @@ def __init__( "not the class: %s" % self.handler ) - self.auth_data = auth_data if auth_data else set([]) + self.auth_data = set(auth_data if auth_data else []) self.default_acl = default_acl self.randomize_hosts = randomize_hosts - self.hosts = None - self.chroot = None + # FIXME Note: hosts and chroot are set by set_hosts, which also checks + # for chroot changes at runtime, so we initialize them to None here to + # avoid confusion with the empty string that set_hosts would set them + # to. This is massively hacky as set_hosts is only called from here + # anyway, but I want to make this change minimally invasive. + # we should really do self.hosts, self.chroot = self.set_hosts(hosts) + # and have set_hosts return the hosts and chroot + self.hosts: list[tuple[str, int]] = None # type: ignore[assignment] + self.chroot: str = None # type: ignore[assignment] self.set_hosts(hosts) self.use_ssl = use_ssl @@ -247,11 +363,15 @@ def __init__( self.ca = ca # Curator like simplified state tracking, and listeners for # state transitions - self._state = KeeperState.CLOSED - self.state = KazooState.LOST - self.state_listeners = set() - self._child_watchers = defaultdict(set) - self._data_watchers = defaultdict(set) + self._state: KeeperState = KeeperState.CLOSED + self.state: KazooState = KazooState.LOST + self.state_listeners: set[ListenerFunc] = set() + self._child_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) + self._data_watchers: defaultdict[str, Set[WatchFunc]] = defaultdict( + set + ) self._reset() self.read_only = read_only @@ -272,7 +392,14 @@ def __init__( self._stopped.set() self._writer_stopped.set() - self.retry = self._conn_retry = None + # FIXME This is kind of gross but we need to set these to something so + # that the type checker will understand that they are set by the time + # they are used and that they have the right type. + # We would do better to use a few variables/functions instead of + # overloading self.retry but this is a bit less invasive to the code + # and the type checker can understand it with a few hacks + self.retry: KazooRetry = None # type: ignore[assignment] + self._conn_retry: KazooRetry = None # type: ignore[assignment] if type(connection_retry) is dict: self._conn_retry = KazooRetry(**connection_retry) @@ -299,7 +426,11 @@ def __init__( ) if self.retry is None or self._conn_retry is None: - old_retry_keys = dict(_RETRY_COMPAT_DEFAULTS) + # Note: because of the hacks at line 280, mypy thinks this is + # unreachable + old_retry_keys = dict( # type: ignore[unreachable] + _RETRY_COMPAT_DEFAULTS + ) for key in old_retry_keys: try: old_retry_keys[key] = kwargs.pop(key) @@ -320,11 +451,13 @@ def __init__( if self._conn_retry is None: self._conn_retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) if self.retry is None: self.retry = KazooRetry( - sleep_func=self.handler.sleep_func, **retry_keys + sleep_func=self.handler.sleep_func, + **retry_keys, ) # Managing legacy SASL options @@ -372,10 +505,21 @@ def __init__( # to avoid shared retry counts self._retry = self.retry - def _retry(*args, **kwargs): - return self._retry.copy()(*args, **kwargs) - - self.retry = _retry + def _retry( + func: Callable[GenericArgs, KazooRetry.RETRY_RETURN], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, + ) -> KazooRetry.RETRY_RETURN: + return self._retry.copy()(func, *args, **kwargs) + + # FIXME + # (expression has type "Callable[[VarArg(Any), KwArg(Any)], Any]", + # variable has type "KazooRetry") so basically self.retry needs to be + # set to that and then the type checker will understand that + # self.retry.copy() is a valid call. This is just a mess and needs the + # code rearranging to be more mypy friendly but this is the least + # invasive way to do it for now + self.retry = _retry # type: ignore[assignment] self.Barrier = partial(Barrier, self) self.Counter = partial(Counter, self) @@ -402,18 +546,18 @@ def _retry(*args, **kwargs): % (kwargs.keys(),) ) - def _reset(self): + def _reset(self) -> None: """Resets a variety of client states for a new connection.""" - self._queue = deque() - self._pending = deque() + self._queue: Deque[tuple[Any, IAsyncResult]] = deque() + self._pending: Deque[tuple[Any, IAsyncResult, int]] = deque() self._reset_watchers() self._reset_session() self.last_zxid = 0 - self._protocol_version = None + self._protocol_version: int | None = None - def _reset_watchers(self): - watchers = [] + def _reset_watchers(self) -> None: + watchers: list[WatchFunc] = [] for child_watchers in self._child_watchers.values(): watchers.extend(child_watchers) @@ -427,12 +571,12 @@ def _reset_watchers(self): for watch in watchers: self.handler.dispatch_callback(Callback("watch", watch, (ev,))) - def _reset_session(self): + def _reset_session(self) -> None: self._session_id = None self._session_passwd = b"\x00" * 16 @property - def client_state(self): + def client_state(self) -> KeeperState: """Returns the last Zookeeper client state This is the non-simplified state information and is generally @@ -442,7 +586,7 @@ def client_state(self): return self._state @property - def client_id(self): + def client_id(self) -> tuple[int | None, bytes] | None: """Returns the client id for this Zookeeper session if connected. @@ -455,12 +599,16 @@ def client_id(self): return None @property - def connected(self): + def connected(self) -> bool: """Returns whether the Zookeeper connection has been established.""" return self._live.is_set() - def set_hosts(self, hosts, randomize_hosts=None): + def set_hosts( + self, + hosts: str | list[str], + randomize_hosts: bool | None = None, + ) -> None: """sets the list of hosts used by this client. This function accepts the same format hosts parameter as the init @@ -504,7 +652,7 @@ def set_hosts(self, hosts, randomize_hosts=None): self.chroot = new_chroot - def add_listener(self, listener): + def add_listener(self, listener: ListenerFunc) -> None: """Add a function to be called for connection state changes. This function will be called with a @@ -519,15 +667,20 @@ def add_listener(self, listener): should be used so that the listener can return immediately. """ - if not (listener and callable(listener)): + # This check should be unnecessary but protects against people who are + # not using type checkers and accidentally passing in something that + # isn't callable. It should be removed. + if not ( + listener and callable(listener) # type: ignore[truthy-function] + ): raise ConfigurationError("listener must be callable") self.state_listeners.add(listener) - def remove_listener(self, listener): + def remove_listener(self, listener: ListenerFunc) -> None: """Remove a listener function""" self.state_listeners.discard(listener) - def _make_state_change(self, state): + def _make_state_change(self, state: KazooState) -> None: # skip if state is current if self.state == state: return @@ -544,7 +697,7 @@ def _make_state_change(self, state): except Exception: self.logger.exception("Error in connection state listener") - def _session_callback(self, state): + def _session_callback(self, state: KeeperState) -> None: if state == self._state: return @@ -581,9 +734,10 @@ def _session_callback(self, state): self._make_state_change(KazooState.SUSPENDED) self._reset_watchers() - def _notify_pending(self, state): + def _notify_pending(self, state: KeeperState) -> None: """Used to clear a pending response queue and request queue during connection drops.""" + exc: KazooException if state == KeeperState.AUTH_FAILED: exc = AuthFailedError() elif state == KeeperState.EXPIRED_SESSION: @@ -607,7 +761,7 @@ def _notify_pending(self, state): except IndexError: break - def _safe_close(self): + def _safe_close(self) -> None: self.handler.stop() timeout = self._session_timeout // 1000 if timeout < 10: @@ -618,7 +772,9 @@ def _safe_close(self): "and wouldn't close after %s seconds" % timeout ) - def _call(self, request, async_object): + def _call( + self, request: object, async_object: IAsyncResult + ) -> bool | None: """Ensure the client is in CONNECTED or SUSPENDED state and put the request in the queue if it is. @@ -647,14 +803,16 @@ def _call(self, request, async_object): async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None try: write_sock.send(b"\0") except: # NOQA async_object.set_exception( ConnectionClosedError("Connection has been closed") ) + return None - def start(self, timeout=15): + def start(self, timeout: float = 15.0) -> None: """Initiate connection to ZK. :param timeout: Time in seconds to wait for connection to @@ -678,7 +836,7 @@ def start(self, timeout=15): "should be created before normal use." ) - def start_async(self): + def start_async(self) -> Event: """Asynchronously initiate connection to ZK. :returns: An event object that can be checked to see if the @@ -705,7 +863,7 @@ def start_async(self): self._connection.start() return self._live - def stop(self): + def stop(self) -> None: """Gracefully stop this Zookeeper session. This method can be called while a reconnection attempt is in @@ -721,18 +879,22 @@ def stop(self): return self._stopped.set() - self._queue.append((CloseInstance, None)) + self._queue.append((CloseInstance, cast("IAsyncResult", None))) try: - self._connection._write_sock.send(b"\0") + # This assert should never fail since the connection should + # have been started but I'm not sure how to persaude mypy of that + self._connection._write_sock.send( # type: ignore[union-attr] + b"\0" + ) finally: self._safe_close() - def restart(self): + def restart(self) -> None: """Stop and restart the Zookeeper session.""" self.stop() self.start() - def close(self): + def close(self) -> None: """Free any resources held by the client. This method should be called on a stopped client before it is @@ -742,7 +904,7 @@ def close(self): """ self._connection.close() - def command(self, cmd=b"ruok"): + def command(self, cmd: bytes = b"ruok") -> str: """Sent a management command to the current ZK server. Examples are `ruok`, `envi` or `stat`. @@ -761,8 +923,18 @@ def command(self, cmd=b"ruok"): if not self._live.is_set(): raise ConnectionLoss("No connection to server") - peer = self._connection._socket.getpeername()[:2] - peer_host = self._connection._socket.getpeername()[1] + # Need a way of persauding mypy that the connection is live and thus + # the socket is not None + peer = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + :2 + ] + ) + peer_host = ( + self._connection._socket.getpeername()[ # type: ignore[union-attr] + 1 + ] + ) sock = self.handler.create_connection( peer, hostname=peer_host, @@ -780,7 +952,7 @@ def command(self, cmd=b"ruok"): sock.close() return result.decode("utf-8", "replace") - def server_version(self, retries=3): + def server_version(self, retries: int = 3) -> tuple[int, ...]: """Get the version of the currently connected ZK server. :returns: The server version, for example (3, 4, 3). @@ -790,7 +962,7 @@ def server_version(self, retries=3): """ - def _try_fetch(): + def _try_fetch() -> tuple[int, ...] | None: data = self.command(b"envi") data_parsed = {} for line in data.splitlines(): @@ -804,13 +976,19 @@ def _try_fetch(): if k: data_parsed[k] = v version = data_parsed.get(ENVI_VERSION_KEY, "") - version_digits = ENVI_VERSION.match(version).group(1) + # FIXME If you get an unexpected answer, you'll crash - not + # changing the code, so just ignoring the type error + version_digits = ENVI_VERSION.match( + version + ).group( # type: ignore[union-attr] + 1 + ) try: return tuple([int(d) for d in version_digits.split(".")]) except ValueError: return None - def _is_valid(version): + def _is_valid(version: tuple[int, ...] | None) -> bool: # All zookeeper versions should have at least major.minor # version numbers; if we get one that doesn't it is likely not # correct and was truncated... @@ -818,21 +996,29 @@ def _is_valid(version): return True return False + # FIXME A better way of doing this would be to put the initial + # _try_fetch in the loop and inline _is_valid but I want to minimise + # code changes + # Try 1 + retries amount of times to get a version that we know # will likely be acceptable... version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + # and the next 2 suppress should include return-value + # but hound is broken + return version # type: ignore for _i in range(0, retries): version = _try_fetch() if _is_valid(version): - return version + # mypy doesn't recognise that _is_valid guarantees this + return version # type: ignore raise KazooException( "Unable to fetch useable server" " version after trying %s times" % (1 + max(0, retries)) ) - def add_auth(self, scheme, credential): + def add_auth(self, scheme: str, credential: str) -> bool: """Send credentials to server. :param scheme: authentication scheme (default supported: @@ -847,9 +1033,9 @@ def add_auth(self, scheme, credential): the session state will be set to AUTH_FAILED as well. """ - return self.add_auth_async(scheme, credential).get() + return cast("bool", self.add_auth_async(scheme, credential).get()) - def add_auth_async(self, scheme, credential): + def add_auth_async(self, scheme: str, credential: str) -> IAsyncResult: """Asynchronously send credentials to server. Takes the same arguments as :meth:`add_auth`. @@ -868,7 +1054,7 @@ def add_auth_async(self, scheme, credential): self._call(Auth(0, scheme, credential), async_result) return async_result - def unchroot(self, path): + def unchroot(self, path: str) -> str: """Strip the chroot if applicable from the path.""" if not self.chroot: return path @@ -879,7 +1065,7 @@ def unchroot(self, path): else: return path - def sync_async(self, path): + def sync_async(self, path: str) -> IAsyncResult: """Asynchronous sync. :rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -888,10 +1074,10 @@ def sync_async(self, path): async_result = self.handler.async_result() @wrap(async_result) - def _sync_completion(result): + def _sync_completion(result: IAsyncResult) -> str: return self.unchroot(result.get()) - def _do_sync(): + def _do_sync() -> None: result = self.handler.async_result() self._call(Sync(_prefix_root(self.chroot, path)), result) result.rawlink(_sync_completion) @@ -899,7 +1085,7 @@ def _do_sync(): _do_sync() return async_result - def sync(self, path): + def sync(self, path: str) -> str: """Sync, blocks until response is acknowledged. Flushes channel between process and leader. @@ -913,18 +1099,44 @@ def sync(self, path): .. versionadded:: 0.5 """ - return self.sync_async(path).get() + return cast("str", self.sync_async(path).get()) + + @overload + def create( + self, + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[False] = False, + ) -> str: + ... + + @overload + def create( + self, + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: Literal[True] = True, + ) -> tuple[str, ZnodeStat]: + ... def create( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> str | tuple[str, ZnodeStat]: """Create a node with the given value as its data. Optionally set an ACL on the node. @@ -1003,26 +1215,29 @@ def create( The `include_data` option. """ acl = acl or self.default_acl - return self.create_async( - path, - value, - acl=acl, - ephemeral=ephemeral, - sequence=sequence, - makepath=makepath, - include_data=include_data, - ).get() + return cast( + "str | tuple[str, ZnodeStat]", + self.create_async( + path, + value, + acl=acl, + ephemeral=ephemeral, + sequence=sequence, + makepath=makepath, + include_data=include_data, + ).get(), + ) def create_async( self, - path, - value=b"", - acl=None, - ephemeral=False, - sequence=False, - makepath=False, - include_data=False, - ): + path: str, + value: bytes | None = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + makepath: bool = False, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously create a ZNode. Takes the same arguments as :meth:`create`. @@ -1066,11 +1281,16 @@ def create_async( async_result = self.handler.async_result() @capture_exceptions(async_result) - def do_create(): + def do_create() -> None: result = self._create_async_inner( path, value, - acl, + # The way acl is constructed ends up confusing mypy, which + # thinks that acl can be None here, even though the code + # above ensures that if acl is None, it gets set to + # OPEN_ACL_UNSAFE, so we ignore the type error here. + # behaves differently in python3.8 and python3.14, sigh. + acl, # type: ignore[arg-type] flags, trailing=sequence, include_data=include_data, @@ -1078,12 +1298,14 @@ def do_create(): result.rawlink(create_completion) @capture_exceptions(async_result) - def retry_completion(result): + def retry_completion(result: IAsyncResult) -> None: result.get() do_create() @wrap(async_result) - def create_completion(result): + def create_completion( + result: IAsyncResult, + ) -> str | tuple[str, ZnodeStat] | None: try: if include_data: new_path, stat = result.get() @@ -1098,18 +1320,22 @@ def create_completion(result): else: parent, _ = split(path) self.ensure_path_async(parent, acl).rawlink(retry_completion) + return None do_create() return async_result def _create_async_inner( - self, path, value, acl, flags, trailing=False, include_data=False - ): + self, + path: str, + value: bytes | None, + acl: Sequence[ACL], + flags: int, + trailing: bool = False, + include_data: bool = False, + ) -> IAsyncResult: async_result = self.handler.async_result() - if include_data: - opcode = Create2 - else: - opcode = Create + opcode = Create2 if include_data else Create call_result = self._call( opcode( @@ -1126,19 +1352,24 @@ def _create_async_inner( # exception upwards to the do_create function in # KazooClient.create so that it gets set on the correct # async_result object - raise async_result.exception + # Note: Do we actually need call_result? It seems like we could + # just check the state of the exception, and avoid the typing + # stuff. + raise async_result.exception # type: ignore[misc] return async_result - def ensure_path(self, path, acl=None): + def ensure_path(self, path: str, acl: Sequence[ACL] | None = None) -> bool: """Recursively create a path if it doesn't exist. :param path: Path of node. :param acl: Permissions for node. """ - return self.ensure_path_async(path, acl).get() + return cast("bool", self.ensure_path_async(path, acl).get()) - def ensure_path_async(self, path, acl=None): + def ensure_path_async( + self, path: str, acl: Sequence[ACL] | None = None + ) -> IAsyncResult: """Recursively create a path asynchronously if it doesn't exist. Takes the same arguments as :meth:`ensure_path`. @@ -1151,19 +1382,21 @@ def ensure_path_async(self, path, acl=None): async_result = self.handler.async_result() @wrap(async_result) - def create_completion(result): + def create_completion(result: IAsyncResult) -> bool: try: - return result.get() + return cast("bool", result.get()) except NodeExistsError: return True @capture_exceptions(async_result) - def prepare_completion(next_path, result): + def prepare_completion(next_path: str, result: IAsyncResult) -> None: result.get() self.create_async(next_path, acl=acl).rawlink(create_completion) @wrap(async_result) - def exists_completion(path, result): + def exists_completion( + path: str, result: IAsyncResult + ) -> Literal[True] | None: if result.get(): return True parent, node = split(path) @@ -1173,12 +1406,15 @@ def exists_completion(path, result): ) else: self.create_async(path, acl=acl).rawlink(create_completion) + return None self.exists_async(path).rawlink(partial(exists_completion, path)) return async_result - def exists(self, path, watch=None): + def exists( + self, path: str, watch: WatchFunc | None = None + ) -> ZnodeStat | None: """Check if a node exists. If a watch is provided, it will be left on the node with the @@ -1198,9 +1434,13 @@ def exists(self, path, watch=None): returns a non-zero error code. """ - return self.exists_async(path, watch=watch).get() + return cast( + "ZnodeStat | None", self.exists_async(path, watch=watch).get() + ) - def exists_async(self, path, watch=None): + def exists_async( + self, path: str, watch: WatchFunc | None = None + ) -> IAsyncResult: """Asynchronously check if a node exists. Takes the same arguments as :meth:`exists`. @@ -1218,7 +1458,9 @@ def exists_async(self, path, watch=None): ) return async_result - def get(self, path, watch=None): + def get( + self, path: str, watch: WatchFunc | None = None + ) -> tuple[bytes, ZnodeStat]: """Get the value of a node. If a watch is provided, it will be left on the node with the @@ -1241,9 +1483,13 @@ def get(self, path, watch=None): returns a non-zero error code """ - return self.get_async(path, watch=watch).get() + return cast( + "tuple[bytes, ZnodeStat]", self.get_async(path, watch=watch).get() + ) - def get_async(self, path, watch=None): + def get_async( + self, path: str, watch: WatchFunc | None = None + ) -> IAsyncResult: """Asynchronously get the value of a node. Takes the same arguments as :meth:`get`. @@ -1261,7 +1507,30 @@ def get_async(self, path, watch=None): ) return async_result - def get_children(self, path, watch=None, include_data=False): + @overload + def get_children( + self, + path: str, + watch: WatchFunc | None = None, + include_data: Literal[False] = False, + ) -> list[str]: + ... + + @overload + def get_children( + self, + path: str, + watch: WatchFunc | None = None, + include_data: Literal[True] = True, + ) -> tuple[list[str], ZnodeStat]: + ... + + def get_children( + self, + path: str, + watch: WatchFunc | None = None, + include_data: bool = False, + ) -> list[str] | tuple[list[str], ZnodeStat]: """Get a list of child nodes of a path. If a watch is provided it will be left on the node with the @@ -1295,11 +1564,19 @@ def get_children(self, path, watch=None, include_data=False): The `include_data` option. """ - return self.get_children_async( - path, watch=watch, include_data=include_data - ).get() + return cast( + "list[str] | tuple[list[str], ZnodeStat]", + self.get_children_async( + path, watch=watch, include_data=include_data + ).get(), + ) - def get_children_async(self, path, watch=None, include_data=False): + def get_children_async( + self, + path: str, + watch: WatchFunc | None = None, + include_data: bool = False, + ) -> IAsyncResult: """Asynchronously get a list of child nodes of a path. Takes the same arguments as :meth:`get_children`. @@ -1314,6 +1591,8 @@ def get_children_async(self, path, watch=None, include_data=False): raise TypeError("Invalid type for 'include_data' (bool expected)") async_result = self.handler.async_result() + # FIXME? Do this as req = getc2 if include_data else getc + req: GetChildren | GetChildren2 if include_data: req = GetChildren2(_prefix_root(self.chroot, path), watch) else: @@ -1321,7 +1600,7 @@ def get_children_async(self, path, watch=None, include_data=False): self._call(req, async_result) return async_result - def get_acls(self, path): + def get_acls(self, path: str) -> tuple[list[ACL], ZnodeStat]: """Return the ACL and stat of the node of the given path. :param path: Path of the node. @@ -1339,9 +1618,11 @@ def get_acls(self, path): .. versionadded:: 0.5 """ - return self.get_acls_async(path).get() + return cast( + "tuple[list[ACL], ZnodeStat]", self.get_acls_async(path).get() + ) - def get_acls_async(self, path): + def get_acls_async(self, path: str) -> IAsyncResult: """Return the ACL and stat of the node of the given path. Takes the same arguments as :meth:`get_acls`. @@ -1355,7 +1636,9 @@ def get_acls_async(self, path): self._call(GetACL(_prefix_root(self.chroot, path)), async_result) return async_result - def set_acls(self, path, acls, version=-1): + def set_acls( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> ZnodeStat: """Set the ACL for the node of the given path. Set the ACL for the node of the given path if such a node @@ -1382,9 +1665,13 @@ def set_acls(self, path, acls, version=-1): .. versionadded:: 0.5 """ - return self.set_acls_async(path, acls, version).get() + return cast( + "ZnodeStat", self.set_acls_async(path, acls, version).get() + ) - def set_acls_async(self, path, acls, version=-1): + def set_acls_async( + self, path: str, acls: Sequence[ACL], version: int = -1 + ) -> IAsyncResult: """Set the ACL for the node of the given path. Takes the same arguments as :meth:`set_acls`. @@ -1407,7 +1694,9 @@ def set_acls_async(self, path, acls, version=-1): ) return async_result - def set(self, path, value, version=-1): + def set( + self, path: str, value: bytes | None, version: int = -1 + ) -> ZnodeStat: """Set the value of a node. If the version of the node being updated is newer than the @@ -1440,9 +1729,11 @@ def set(self, path, value, version=-1): returns a non-zero error code. """ - return self.set_async(path, value, version).get() + return cast("ZnodeStat", self.set_async(path, value, version).get()) - def set_async(self, path, value, version=-1): + def set_async( + self, path: str, value: bytes | None, version: int = -1 + ) -> IAsyncResult: """Set the value of a node. Takes the same arguments as :meth:`set`. @@ -1463,7 +1754,7 @@ def set_async(self, path, value, version=-1): ) return async_result - def transaction(self): + def transaction(self) -> TransactionRequest: """Create and return a :class:`TransactionRequest` object Creates a :class:`TransactionRequest` object. A Transaction can @@ -1480,7 +1771,12 @@ def transaction(self): """ return TransactionRequest(self) - def delete(self, path, version=-1, recursive=False): + # This should not return anything. No return value is documented, and the + # two called functions return different things. AFAICT. But for now I + # want to minimise code changes + def delete( + self, path: str, version: int = -1, recursive: bool = False + ) -> Any: """Delete a node. The call will succeed if such a node exists, and the given @@ -1518,7 +1814,7 @@ def delete(self, path, version=-1, recursive=False): else: return self.delete_async(path, version).get() - def delete_async(self, path, version=-1): + def delete_async(self, path: str, version: int = -1) -> IAsyncResult: """Asynchronously delete a node. Takes the same arguments as :meth:`delete`, with the exception of `recursive`. @@ -1535,7 +1831,7 @@ def delete_async(self, path, version=-1): ) return async_result - def _delete_recursive(self, path): + def _delete_recursive(self, path: str) -> Literal[True] | None: try: children = self.get_children(path) except NoNodeError: @@ -1553,8 +1849,15 @@ def _delete_recursive(self, path): self.delete(path) except NoNodeError: # pragma: nocover pass + return None - def reconfig(self, joining, leaving, new_members, from_config=-1): + def reconfig( + self, + joining: str | None, + leaving: str | None, + new_members: str | None, + from_config: int = -1, + ) -> tuple[bytes, ZnodeStat]: """Reconfig a cluster. This call will succeed if the cluster was reconfigured accordingly. @@ -1625,9 +1928,15 @@ def reconfig(self, joining, leaving, new_members, from_config=-1): result = self.reconfig_async( joining, leaving, new_members, from_config ) - return result.get() + return cast("tuple[bytes, ZnodeStat]", result.get()) - def reconfig_async(self, joining, leaving, new_members, from_config): + def reconfig_async( + self, + joining: str | None, + leaving: str | None, + new_members: str | None, + from_config: int, + ) -> IAsyncResult: """Asynchronously reconfig a cluster. Takes the same arguments as :meth:`reconfig`. @@ -1674,14 +1983,19 @@ class TransactionRequest(object): """ - def __init__(self, client): + def __init__(self, client: KazooClient): self.client = client - self.operations = [] + self.operations: list[Any] = [] self.committed = False def create( - self, path, value=b"", acl=None, ephemeral=False, sequence=False - ): + self, + path: str, + value: bytes = b"", + acl: Sequence[ACL] | None = None, + ephemeral: bool = False, + sequence: bool = False, + ) -> None: """Add a create ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.create`, with the exception of `makepath`. @@ -1718,7 +2032,7 @@ def create( None, ) - def delete(self, path, version=-1): + def delete(self, path: str, version: int = -1) -> None: """Add a delete ZNode to the transaction. Takes the same arguments as :meth:`KazooClient.delete`, with the exception of `recursive`. @@ -1730,7 +2044,7 @@ def delete(self, path, version=-1): raise TypeError("Invalid type for 'version' (int expected)") self._add(Delete(_prefix_root(self.client.chroot, path), version)) - def set_data(self, path, value, version=-1): + def set_data(self, path: str, value: bytes, version: int = -1) -> None: """Add a set ZNode value to the transaction. Takes the same arguments as :meth:`KazooClient.set`. @@ -1745,7 +2059,7 @@ def set_data(self, path, value, version=-1): SetData(_prefix_root(self.client.chroot, path), value, version) ) - def check(self, path, version): + def check(self, path: str, version: int) -> None: """Add a Check Version to the transaction. This command will fail and abort a transaction if the path @@ -1760,7 +2074,7 @@ def check(self, path, version): CheckVersion(_prefix_root(self.client.chroot, path), version) ) - def commit_async(self): + def commit_async(self) -> IAsyncResult: """Commit the transaction asynchronously. :rtype: :class:`~kazoo.interfaces.IAsyncResult` @@ -1772,28 +2086,38 @@ def commit_async(self): self.client._call(Transaction(self.operations), async_object) return async_object - def commit(self): + def commit(self) -> list[Any]: """Commit the transaction. :returns: A list of the results for each operation in the transaction. """ - return self.commit_async().get() + return cast("list[Any]", self.commit_async().get()) - def __enter__(self): + def __enter__(self) -> TransactionRequest: return self - def __exit__(self, exc_type, exc_value, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: """Commit and cleanup accumulated transaction data.""" if not exc_type: self.commit() + return None - def _check_tx_state(self): + def _check_tx_state(self) -> None: if self.committed: raise ValueError("Transaction already committed") - def _add(self, request, post_processor=None): + def _add( + self, + request: Any, + post_processor: Callable[[Any], Any] | None = None, + ) -> None: self._check_tx_state() self.client.logger.log(BLATHER, "Added %r to %r", request, self) self.operations.append(request) diff --git a/kazoo/exceptions.py b/kazoo/exceptions.py index b24c697cb..ae9341df1 100644 --- a/kazoo/exceptions.py +++ b/kazoo/exceptions.py @@ -1,6 +1,9 @@ """Kazoo Exceptions""" +from __future__ import annotations + from collections import defaultdict +from typing import Callable, Type class KazooException(Exception): @@ -51,17 +54,25 @@ class SASLException(KazooException): """ -def _invalid_error_code(): +def _invalid_error_code() -> Type[ZookeeperError]: raise RuntimeError("Invalid error code") -EXCEPTIONS = defaultdict(_invalid_error_code) +EXCEPTIONS: defaultdict[int, Type[ZookeeperError]] = defaultdict( + _invalid_error_code +) -def _zookeeper_exception(code): - def decorator(klass): +def _zookeeper_exception( + code: int, +) -> Callable[[Type[ZookeeperError]], Type[ZookeeperError]]: + def decorator(klass: Type[ZookeeperError]) -> Type[ZookeeperError]: EXCEPTIONS[code] = klass - klass.code = code + # Unfortunately there is currently no good of doing the assignment here + # in a way that type checkers would allow. It's a known problem (see + # https://discuss.python.org/t/how-to-type-hint-a-class-decorator/63010 + # ) + klass.code = code # type: ignore[attr-defined] return klass return decorator diff --git a/kazoo/handlers/eventlet.py b/kazoo/handlers/eventlet.py index 8869cc570..ee8ed47d6 100644 --- a/kazoo/handlers/eventlet.py +++ b/kazoo/handlers/eventlet.py @@ -1,10 +1,14 @@ """A eventlet based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import contextlib import logging +from typing import cast, Any, Generator, TYPE_CHECKING + import eventlet from eventlet.green import socket as green_socket from eventlet.green import time as green_time @@ -15,6 +19,18 @@ from kazoo.handlers import utils from kazoo.handlers.utils import selector_select +if TYPE_CHECKING: + from kazoo.interfaces import ( + Event, + FdLike, + IHandler, + Lockable, + ReentrantLock, + Socket, + ) + from kazoo.protocol.states import Callback + + LOG = logging.getLogger(__name__) # sentinel objects @@ -22,17 +38,17 @@ @contextlib.contextmanager -def _yield_before_after(): +def _yield_before_after() -> Generator[None, None, None]: # Yield to any other co-routines... # # See: http://eventlet.net/doc/modules/greenthread.html # for how this zero sleep is really a cooperative yield to other potential # co-routines... - eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] try: yield finally: - eventlet.sleep(0) + eventlet.sleep(0) # type: ignore[no-untyped-call] class TimeoutError(Exception): @@ -42,12 +58,15 @@ class TimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: IHandler): super(AsyncResult, self).__init__( - handler, green_threading.Condition, TimeoutError + handler, + green_threading.Condition, # type: ignore[attr-defined] + TimeoutError, ) +# FIXME This should inherit from IHandler class SequentialEventletHandler(object): """Eventlet handler for sequentially executing callbacks. @@ -81,26 +100,35 @@ class SequentialEventletHandler(object): queue_impl = green_queue.LightQueue queue_empty = green_queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialEventletHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() - self._workers = [] + self.callback_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self.completion_queue = ( + self.queue_impl() # type: ignore[no-untyped-call] + ) + self._workers: list[ # type: ignore[name-defined] + tuple[ + eventlet.GreenThread, + green_queue.LightQueue, + ] + ] = [] self._started = False @staticmethod - def sleep_func(wait): - green_time.sleep(wait) + def sleep_func(wait: float) -> None: + green_time.sleep(wait) # type: ignore[attr-defined, no-untyped-call] @property - def running(self): + def running(self) -> bool: return self._started timeout_exception = TimeoutError - def _process_completion_queue(self): + def _process_completion_queue(self) -> None: while True: - cb = self.completion_queue.get() + cb = self.completion_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -114,9 +142,9 @@ def _process_completion_queue(self): finally: del cb # release before possible idle - def _process_callback_queue(self): + def _process_callback_queue(self) -> None: while True: - cb = self.callback_queue.get() + cb = self.callback_queue.get() # type: ignore[no-untyped-call] if cb is _STOP: break try: @@ -130,58 +158,83 @@ def _process_callback_queue(self): finally: del cb # release before possible idle - def start(self): + def start(self) -> None: if not self._started: # Spawn our worker threads, we have # - A callback worker for watch events to be called # - A completion worker for completion events to be called - w = eventlet.spawn(self._process_completion_queue) + w = eventlet.spawn( + self._process_completion_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.completion_queue)) - w = eventlet.spawn(self._process_callback_queue) + w = eventlet.spawn( + self._process_callback_queue # type: ignore[no-untyped-call] + ) self._workers.append((w, self.callback_queue)) self._started = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: while self._workers: w, q = self._workers.pop() - q.put(_STOP) + q.put(_STOP) # type: ignore[no-untyped-call] w.wait() self._started = False atexit.unregister(self.stop) - def socket(self, *args, **kwargs): + def socket(self) -> Socket: return utils.create_tcp_socket(green_socket) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(green_socket) - def event_object(self): - return green_threading.Event() + def event_object(self) -> Event: + return cast( + "Event", green_threading.Event() # type: ignore[attr-defined] + ) - def lock_object(self): - return green_threading.Lock() + def lock_object(self) -> Lockable: + return cast( + "Lockable", green_threading.Lock() # type: ignore[attr-defined] + ) - def rlock_object(self): - return green_threading.RLock() + def rlock_object(self) -> ReentrantLock: + return cast( + "ReentrantLock", + green_threading.RLock(), # type: ignore[attr-defined] + ) - def create_connection(self, *args, **kwargs): + # FIXME fix parameters + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(green_socket, *args, **kwargs) - def select(self, *args, **kwargs): + # FIXME fix parameters + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: with _yield_before_after(): + # Following appears to be a bug in mypy (see + # https://github.com/python/mypy/issues/6799) return selector_select( - *args, selectors_module=green_selectors, **kwargs + *args, + selectors_module=green_selectors, # type: ignore[misc] + **kwargs, ) - def async_result(self): - return AsyncResult(self) + def async_result(self) -> AsyncResult: + return AsyncResult(self) # type: ignore[arg-type] - def spawn(self, func, *args, **kwargs): - t = green_threading.Thread(target=func, args=args, kwargs=kwargs) + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> green_threading.Thread: # type: ignore[name-defined] + t = green_threading.Thread( # type: ignore[attr-defined] + target=func, args=args, kwargs=kwargs + ) t.daemon = True t.start() return t - def dispatch_callback(self, callback): - self.callback_queue.put(lambda: callback.func(*callback.args)) + def dispatch_callback(self, callback: Callback) -> None: + self.callback_queue.put( # type: ignore[no-untyped-call] + lambda: callback.func(*callback.args) + ) diff --git a/kazoo/handlers/gevent.py b/kazoo/handlers/gevent.py index f36389aac..7a3c215cf 100644 --- a/kazoo/handlers/gevent.py +++ b/kazoo/handlers/gevent.py @@ -1,9 +1,13 @@ """A gevent based handler.""" + +from __future__ import annotations from __future__ import absolute_import import atexit import logging +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast + import gevent from gevent import socket import gevent.event @@ -14,18 +18,30 @@ from kazoo.handlers.utils import selector_select +from gevent import Greenlet from gevent.lock import Semaphore, RLock from kazoo.handlers import utils +if TYPE_CHECKING: + from kazoo.interfaces import FdLike, Lockable, Socket + from kazoo.protocol.states import Callback + _using_libevent = gevent.__version__.startswith("0.") log = logging.getLogger(__name__) + _STOP = object() AsyncResult = gevent.event.AsyncResult +# The following would be great typenames, but python3.8 complains about type +# objects not being indexable. +# GCallback = Callable[..., None] +# Worker = Greenlet[..., Any] +# CallbackQueue = gevent.queue.Queue[Callable[..., None]] + class SequentialGeventHandler(object): """Gevent handler for sequentially executing callbacks. @@ -53,24 +69,28 @@ class SequentialGeventHandler(object): queue_empty = gevent.queue.Empty sleep_func = staticmethod(gevent.sleep) - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialGeventHandler` instance""" - self.callback_queue = self.queue_impl() + self.callback_queue: gevent.queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._async = None self._state_change = Semaphore() - self._workers = [] + self._workers: list[Greenlet[..., Any]] = [] @property - def running(self): + def running(self) -> bool: return self._running class timeout_exception(gevent.Timeout): - def __init__(self, msg): + def __init__(self, msg: Any): gevent.Timeout.__init__(self, exception=msg) - def _create_greenlet_worker(self, queue): - def greenlet_worker(): + def _create_greenlet_worker( + self, queue: gevent.queue.Queue[Callable[..., None]] + ) -> Greenlet[..., Any]: + def greenlet_worker() -> None: while True: try: func = queue.get() @@ -88,7 +108,7 @@ def greenlet_worker(): return gevent.spawn(greenlet_worker) - def start(self): + def start(self) -> None: """Start the greenlet workers.""" with self._state_change: if self._running: @@ -98,12 +118,13 @@ def start(self): # Spawn our worker greenlets, we have # - A callback worker for watch events to be called + # FIXME Why the loop? for queue in (self.callback_queue,): w = self._create_greenlet_worker(queue) self._workers.append(w) atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the greenlet workers and empty all queues.""" with self._state_change: if not self._running: @@ -112,7 +133,7 @@ def stop(self): self._running = False for queue in (self.callback_queue,): - queue.put(_STOP) + queue.put(cast("Callable[..., None]", _STOP)) while self._workers: worker = self._workers.pop() @@ -123,33 +144,45 @@ def stop(self): atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: + # FIXME use the correct arguments, not *args, *kwargs return selector_select( - *args, selectors_module=gevent.selectors, **kwargs + # Likely a bug in mypy (see + # https://github.com/python/mypy/issues/6799) + *args, + selectors_module=gevent.selectors, + **kwargs, # type: ignore[misc] ) - def socket(self, *args, **kwargs): + def socket(self) -> Socket: + # See above return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + # See above return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> gevent.event.Event: """Create an appropriate Event object""" return gevent.event.Event() - def lock_object(self): + def lock_object(self) -> Lockable: """Create an appropriate Lock object""" - return gevent.thread.allocate_lock() + return cast( + "Lockable", + gevent.thread.allocate_lock(), # type: ignore[no-untyped-call] + ) - def rlock_object(self): + def rlock_object(self) -> RLock: """Create an appropriate RLock object""" return RLock() - def async_result(self): + def async_result(self) -> AsyncResult[Any]: """Create a :class:`AsyncResult` instance The :class:`AsyncResult` instance will have its completion @@ -160,11 +193,13 @@ def async_result(self): """ return AsyncResult() - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: Any, *args: Any, **kwargs: Any + ) -> gevent.Greenlet[..., Any]: """Spawn a function to run asynchronously""" return gevent.spawn(func, *args, **kwargs) - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/threading.py b/kazoo/handlers/threading.py index b9acd8756..829a70107 100644 --- a/kazoo/handlers/threading.py +++ b/kazoo/handlers/threading.py @@ -10,6 +10,8 @@ :class:`~kazoo.handlers.gevent.SequentialGeventHandler` instead. """ + +from __future__ import annotations from __future__ import absolute_import import atexit @@ -19,9 +21,22 @@ import threading import time +from typing import Any, Callable, Iterable, TYPE_CHECKING, cast + from kazoo.handlers import utils from kazoo.handlers.utils import selector_select - +from kazoo.interfaces import IHandler + +if TYPE_CHECKING: + from kazoo.interfaces import ( + Event, + FdLike, + Lockable, + ReentrantLock, + Socket, + SpawnedFunc, + ) + from kazoo.protocol.states import Callback # sentinel objects _STOP = object() @@ -29,7 +44,7 @@ log = logging.getLogger(__name__) -def _to_fileno(obj): +def _to_fileno(obj: FdLike) -> int: if isinstance(obj, int): fd = int(obj) elif hasattr(obj, "fileno"): @@ -55,13 +70,13 @@ class KazooTimeoutError(Exception): class AsyncResult(utils.AsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler): + def __init__(self, handler: Any) -> None: super(AsyncResult, self).__init__( handler, threading.Condition, KazooTimeoutError ) -class SequentialThreadingHandler(object): +class SequentialThreadingHandler(IHandler): """Threading handler for sequentially executing callbacks. This handler executes callbacks in a sequential manner. A queue is @@ -96,20 +111,26 @@ class SequentialThreadingHandler(object): queue_impl = queue.Queue queue_empty = queue.Empty - def __init__(self): + def __init__(self) -> None: """Create a :class:`SequentialThreadingHandler` instance""" - self.callback_queue = self.queue_impl() - self.completion_queue = self.queue_impl() + self.callback_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() + self.completion_queue: queue.Queue[ + Callable[..., None] + ] = self.queue_impl() self._running = False self._state_change = threading.Lock() - self._workers = [] + self._workers: list[threading.Thread] = [] @property - def running(self): + def running(self) -> bool: return self._running - def _create_thread_worker(self, work_queue): - def _thread_worker(): # pragma: nocover + def _create_thread_worker( + self, work_queue: queue.Queue[Callable[..., None]] + ) -> threading.Thread: + def _thread_worker() -> None: # pragma: nocover while True: try: func = work_queue.get() @@ -128,7 +149,7 @@ def _thread_worker(): # pragma: nocover t = self.spawn(_thread_worker) return t - def start(self): + def start(self) -> None: """Start the worker threads.""" with self._state_change: if self._running: @@ -143,7 +164,7 @@ def start(self): self._running = True atexit.register(self.stop) - def stop(self): + def stop(self) -> None: """Stop the worker threads and empty all queues.""" with self._state_change: if not self._running: @@ -152,7 +173,7 @@ def stop(self): self._running = False for work_queue in (self.completion_queue, self.callback_queue): - work_queue.put(_STOP) + work_queue.put(cast("Callable[..., None]", _STOP)) self._workers.reverse() while self._workers: @@ -164,41 +185,47 @@ def stop(self): self.completion_queue = self.queue_impl() atexit.unregister(self.stop) - def select(self, *args, **kwargs): + def select( + self, *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: return selector_select(*args, **kwargs) - def socket(self): + def socket(self) -> Socket: return utils.create_tcp_socket(socket) - def create_connection(self, *args, **kwargs): + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: return utils.create_tcp_connection(socket, *args, **kwargs) - def create_socket_pair(self): + def create_socket_pair(self) -> tuple[Socket, Socket]: return utils.create_socket_pair(socket) - def event_object(self): + def event_object(self) -> Event: """Create an appropriate Event object""" return threading.Event() - def lock_object(self): + def lock_object(self) -> Lockable: """Create a lock object""" - return threading.Lock() + # Note: This is not ideal, but the ContextManager Protocol seems to + # think you should return an object of the same type. + return cast("Lockable", threading.Lock()) - def rlock_object(self): + def rlock_object(self) -> ReentrantLock: """Create an appropriate RLock object""" - return threading.RLock() + return cast("ReentrantLock", threading.RLock()) - def async_result(self): + def async_result(self) -> AsyncResult: """Create a :class:`AsyncResult` instance""" return AsyncResult(self) - def spawn(self, func, *args, **kwargs): + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> threading.Thread: t = threading.Thread(target=func, args=args, kwargs=kwargs) t.daemon = True t.start() return t - def dispatch_callback(self, callback): + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object The callback is put on separate queues to run depending on the diff --git a/kazoo/handlers/utils.py b/kazoo/handlers/utils.py index 206806f6a..019cc2d22 100644 --- a/kazoo/handlers/utils.py +++ b/kazoo/handlers/utils.py @@ -1,5 +1,7 @@ """Kazoo handler helpers""" +from __future__ import annotations + from collections import defaultdict import errno import functools @@ -8,6 +10,14 @@ import ssl import socket import time +from types import ModuleType +from typing import Any, Callable, Iterable, TypeVar, TYPE_CHECKING +from typing_extensions import ParamSpec + +from kazoo.interfaces import IAsyncResult, FdLike + +if TYPE_CHECKING: + from kazoo.interfaces import Socket HAS_FNCTL = True try: @@ -15,36 +25,51 @@ except ImportError: # pragma: nocover HAS_FNCTL = False + # sentinel objects +# Note: This needs to be a unique object that is not None, as None is used to +# indicate a successful result in AsyncResult. +# This should probably be an Enum, it would certainly be cleaner, but don't +# want to change the code too much. _NONE = object() +CallbackFunc = Callable[..., None] -class AsyncResult(object): + +class AsyncResult(IAsyncResult): """A one-time event that stores a value or an exception""" - def __init__(self, handler, condition_factory, timeout_factory): + def __init__( + self, + handler: Any, + condition_factory: Callable[[], Any], + timeout_factory: Callable[[], Any], + ) -> None: self._handler = handler - self._exception = _NONE + self._exception: object | Exception | None = _NONE self._condition = condition_factory() - self._callbacks = [] + self._callbacks: list[CallbackFunc] = [] self._timeout_factory = timeout_factory self.value = None - def ready(self): + def ready(self) -> bool: """Return true if and only if it holds a value or an exception""" return self._exception is not _NONE - def successful(self): + def successful(self) -> bool: """Return true if and only if it is ready and holds a value""" return self._exception is None @property - def exception(self): + def exception(self) -> Exception | None: if self._exception is not _NONE: - return self._exception + # The next line should have return-value, but hound ci + # is frankly nothing but a hound dog + return self._exception # type: ignore + return None - def set(self, value=None): + def set(self, value: Any = None) -> None: """Store the value. Wake up the waiters.""" with self._condition: self.value = value @@ -52,14 +77,14 @@ def set(self, value=None): self._do_callbacks() self._condition.notify_all() - def set_exception(self, exception): + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters.""" with self._condition: self._exception = exception self._do_callbacks() self._condition.notify_all() - def get(self, block=True, timeout=None): + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception. If there is no value raises TimeoutError. @@ -69,18 +94,18 @@ def get(self, block=True, timeout=None): if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] elif block: self._condition.wait(timeout) if self._exception is not _NONE: if self._exception is None: return self.value - raise self._exception + raise self._exception # type: ignore[misc] # if we get to this point we timeout raise self._timeout_factory() - def get_nowait(self): + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. If nothing is available, raises TimeoutError @@ -88,14 +113,14 @@ def get_nowait(self): """ return self.get(block=False) - def wait(self, timeout=None): + def wait(self, timeout: float | None = None) -> bool: """Block until the instance is ready.""" with self._condition: if not self.ready(): self._condition.wait(timeout) return self._exception is not _NONE - def rawlink(self, callback): + def rawlink(self, callback: CallbackFunc) -> None: """Register a callback to call when a value or an exception is set""" with self._condition: @@ -106,7 +131,7 @@ def rawlink(self, callback): if self.ready(): self._do_callbacks() - def unlink(self, callback): + def unlink(self, callback: CallbackFunc) -> None: """Remove the callback set by :meth:`rawlink`""" with self._condition: if self.ready(): @@ -116,7 +141,7 @@ def unlink(self, callback): if callback in self._callbacks: self._callbacks.remove(callback) - def _do_callbacks(self): + def _do_callbacks(self) -> None: """Execute the callbacks that were registered by :meth:`rawlink`. If the handler is in running state this method only schedules the calls to be performed by the handler. If it's stopped, @@ -131,19 +156,21 @@ def _do_callbacks(self): functools.partial(callback, self)() -def _set_fd_cloexec(fd): +def _set_fd_cloexec(fd: Socket) -> None: flags = fcntl.fcntl(fd, fcntl.F_GETFD) fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) -def _set_default_tcpsock_options(module, sock): +def _set_default_tcpsock_options(module: ModuleType, sock: Socket) -> Socket: sock.setsockopt(module.IPPROTO_TCP, module.TCP_NODELAY, 1) if HAS_FNCTL: _set_fd_cloexec(sock) return sock -def create_socket_pair(module, port=0): +def create_socket_pair( + module: ModuleType, port: int = 0 +) -> tuple[Socket, Socket]: """Create socket pair. If socket.socketpair isn't available, we emulate it. @@ -182,32 +209,32 @@ def create_socket_pair(module, port=0): return client_sock, srv_sock -def create_tcp_socket(module): +def create_tcp_socket(module: ModuleType) -> Socket: """Create a TCP socket with the CLOEXEC flag set.""" type_ = module.SOCK_STREAM if hasattr(module, "SOCK_CLOEXEC"): # pragma: nocover # if available, set cloexec flag during socket creation type_ |= module.SOCK_CLOEXEC - sock = module.socket(module.AF_INET, type_) + sock: Socket = module.socket(module.AF_INET, type_) _set_default_tcpsock_options(module, sock) return sock def create_tcp_connection( - module, - address, - hostname=None, - timeout=None, - use_ssl=False, - ca=None, - certfile=None, - keyfile=None, - keyfile_password=None, - verify_certs=True, - check_hostname=False, - options=None, - ciphers=None, -): + module: ModuleType, + address: tuple[str, str | int], + hostname: str | None = None, + timeout: float | None = None, + use_ssl: bool = False, + ca: str | None = None, + certfile: str | None = None, + keyfile: str | None = None, + keyfile_password: str | None = None, + verify_certs: bool = True, + check_hostname: bool = False, + options: ssl.Options | None = None, + ciphers: str | None = None, +) -> Socket: end = None if timeout is None: # thanks to create_connection() developers for @@ -215,7 +242,7 @@ def create_tcp_connection( timeout = module.getdefaulttimeout() if timeout is not None: end = time.monotonic() + timeout - sock = None + sock: Socket | None = None while True: timeout_at = end if end is None else end - time.monotonic() @@ -279,7 +306,13 @@ def create_tcp_connection( sock = module.create_connection(address, timeout_at) break except Exception as ex: - errnum = ex.errno if isinstance(ex, OSError) else ex[0] + # Seriously WTF? if ex is an exception, how can it be a tuple? + # I guess gevent can do this, but really... + errnum = ( + ex.errno + if isinstance(ex, OSError) + else ex[0] # type: ignore + ) if errnum == errno.EINTR: continue raise @@ -291,7 +324,16 @@ def create_tcp_connection( return sock -def capture_exceptions(async_result): +CapturedResult = TypeVar("CapturedResult") +GenericArgs = ParamSpec("GenericArgs") + + +def capture_exceptions( + async_result: IAsyncResult, +) -> Callable[ + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], +]: """Return a new decorated function that propagates the exceptions of the wrapped function to an async_result. @@ -299,20 +341,30 @@ def capture_exceptions(async_result): """ - def capture(function): + def capture( + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @functools.wraps(function) - def captured_function(*args, **kwargs): + def captured_function( + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs + ) -> CapturedResult | None: try: return function(*args, **kwargs) except Exception as exc: async_result.set_exception(exc) + return None return captured_function return capture -def wrap(async_result): +def wrap( + async_result: IAsyncResult, +) -> Callable[ + [Callable[GenericArgs, CapturedResult]], + Callable[GenericArgs, CapturedResult | None], +]: """Return a new decorated function that propagates the return value or exception of wrapped function to an async_result. NOTE: Only propagates a non-None return value. @@ -321,9 +373,13 @@ def wrap(async_result): """ - def capture(function): + def capture( + function: Callable[GenericArgs, CapturedResult] + ) -> Callable[GenericArgs, CapturedResult | None]: @capture_exceptions(async_result) - def captured_function(*args, **kwargs): + def captured_function( + *args: GenericArgs.args, **kwargs: GenericArgs.kwargs + ) -> CapturedResult | None: value = function(*args, **kwargs) if value is not None: async_result.set(value) @@ -334,7 +390,7 @@ def captured_function(*args, **kwargs): return capture -def fileobj_to_fd(fileobj): +def fileobj_to_fd(fileobj: FdLike) -> int: """Return a file descriptor from a file object. Parameters: @@ -349,18 +405,25 @@ def fileobj_to_fd(fileobj): if isinstance(fileobj, int): fd = fileobj else: + # FIXME given the protocol I don't think the try/catch/int are + # required. try: fd = int(fileobj.fileno()) except (AttributeError, TypeError, ValueError): raise TypeError("Invalid file object: " "{!r}".format(fileobj)) + # FIXME Questionable, just let select deal with it. if fd < 0: raise TypeError("Invalid file descriptor: {}".format(fd)) return fd def selector_select( - rlist, wlist, xlist, timeout=None, selectors_module=selectors -): + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], + timeout: float | None = None, + selectors_module: ModuleType = selectors, +) -> tuple[list[FdLike], list[FdLike], list[FdLike]]: """Selector-based drop-in replacement for select to overcome select limitation on a maximum filehandle value. """ @@ -374,8 +437,8 @@ def selector_select( selectors_module.EVENT_READ: rlist, selectors_module.EVENT_WRITE: wlist, } - fd_events = defaultdict(int) - fd_fileobjs = defaultdict(list) + fd_events: defaultdict[int, int] = defaultdict(int) + fd_fileobjs: defaultdict[int, list[FdLike]] = defaultdict(list) for event, fileobjs in events_mapping.items(): for fileobj in fileobjs: @@ -391,7 +454,9 @@ def selector_select( # gevent can raise OSError raise ValueError("Invalid event mask or fd") from e - revents, wevents, xevents = [], [], [] + revents: list[FdLike] = [] + wevents: list[FdLike] = [] + xevents: list[FdLike] = [] try: ready = selector.select(timeout) finally: diff --git a/kazoo/hosts.py b/kazoo/hosts.py index 3ece93180..cda746a3a 100644 --- a/kazoo/hosts.py +++ b/kazoo/hosts.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import urllib.parse -def collect_hosts(hosts): +def collect_hosts( + hosts: str | list[str], +) -> tuple[list[tuple[str, int]], str | None]: """ Collect a set of hosts and an optional chroot from a string or a list of strings. @@ -12,8 +16,8 @@ def collect_hosts(hosts): else: host_ports, chroot = hosts, None else: - host_ports, chroot = hosts.partition("/")[::2] - host_ports = host_ports.split(",") + host_ports_1, chroot = hosts.partition("/")[::2] + host_ports = host_ports_1.split(",") chroot = "/" + chroot if chroot else None result = [] diff --git a/kazoo/interfaces.py b/kazoo/interfaces.py index 351f1fd89..18688ee57 100644 --- a/kazoo/interfaces.py +++ b/kazoo/interfaces.py @@ -8,10 +8,160 @@ """ +from __future__ import annotations + +import abc +import queue + +from types import TracebackType +from typing import ( + Any, + Callable, + Iterable, + Protocol, + Union, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from kazoo.protocol.states import Callback + # public API -class IHandler(object): +class HasFileNo(Protocol): + """Protocol for objects that support a fileno method.""" + + def fileno(self) -> int: + ... + + +FdLike = Union[int, HasFileNo] + + +class Socket(HasFileNo, Protocol): + """This is for things that provide a socket.socket-like interface. + + This is required because: + 1. The socket in gevent doesn't inherit from socket.socket + 2. mypy gets confused if you have a method called socket and + subsequently attempt to use socket or socket.socket as a return type + """ + + def close(self) -> None: + ... + + def fileno(self) -> int: + ... + + def getpeername(self) -> tuple[str, int]: + ... + + def getsockname(self) -> tuple[str, int]: + ... + + def recv(self, bufsize: int, flags: int = 0) -> bytes: + ... + + def send(self, data: bytes | memoryview, flags: int = 0) -> int: + ... + + def sendall(self, data: bytes, flags: int = 0) -> None: + ... + + def setblocking(self, flags: bool) -> None: + ... + + def setsockopt(self, level: int, optname: int, value: int) -> None: + ... + + def shutdown(self, flag: int) -> None: + ... + + +class Lockable(Protocol): + """This is what threading.Lock implements. + + In python 3.9+ it's available natively. Though given it has some + very odd typing, I wouldn't put money on it. + """ + + def __enter__(self) -> None: + ... + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> int | None: + """The gevent release returns an int...""" + ... + + def locked(self) -> bool: + ... + + +class ReentrantLock(Protocol): + """This is what threading.RLock implements. + + In python 3.14+, it's the same as Lock, which adds to the fun. + """ + + def __enter__(self) -> None: + ... + + def __exit__( + self, + type_: type[BaseException] | None, + value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: + ... + + def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: + ... + + def release(self) -> None: + ... + + +class Event(Protocol): + """Protocol for threading.Event""" + + def is_set(self) -> bool: + ... + + def set(self) -> None: + ... + + def clear(self) -> None: + ... + + def wait(self, timeout: float | None = None) -> bool: + ... + + +class Threadlike(Protocol): + """Protocol for something like a thread.""" + + def is_alive(self) -> bool: + ... + + def join(self, timeout: float | None = None) -> None: + ... + + +SpawnedFunc = Callable[..., None] + + +class IHandler(abc.ABC): """A Callback Handler for Zookeeper completion and watch callbacks. This object must implement several methods responsible for @@ -44,43 +194,69 @@ class IHandler(object): """ - def start(self): + timeout_exception: type[Exception] = None # type: ignore[assignment] + sleep_func: staticmethod[[float], None] = None # type: ignore[assignment] + queue_impl: type[queue.Queue[Any]] = None # type: ignore[assignment] + + @abc.abstractmethod + def start(self) -> None: """Start the handler, used for setting up the handler.""" - def stop(self): + @abc.abstractmethod + def stop(self) -> None: """Stop the handler. Should block until the handler is safely stopped.""" - def select(self): + @abc.abstractmethod + def select( + self, + rlist: Iterable[FdLike], + wlist: Iterable[FdLike], + xlist: Iterable[FdLike], + timeout: float | None = None, + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: """A select method that implements Python's select.select API""" - def socket(self): - """A socket method that implements Python's socket.socket + @abc.abstractmethod + def socket(self) -> Socket: + """A socket method that implements Python's socket.socket API""" + + # FIXME This should have a proper set of parameters. + @abc.abstractmethod + def create_connection(self, *args: Any, **kwargs: Any) -> Socket: + """A socket method that implements Python's socket.create_connection API""" - def create_connection(self): - """A socket method that implements Python's - socket.create_connection API""" + @abc.abstractmethod + def create_socket_pair(self) -> tuple[Socket, Socket]: + """A socket method that implements Python's socket.socketpair API""" - def event_object(self): + @abc.abstractmethod + def event_object(self) -> Event: """Return an appropriate object that implements Python's threading.Event API""" - def lock_object(self): + @abc.abstractmethod + def lock_object(self) -> Lockable: """Return an appropriate object that implements Python's threading.Lock API""" - def rlock_object(self): + @abc.abstractmethod + def rlock_object(self) -> ReentrantLock: """Return an appropriate object that implements Python's threading.RLock API""" - def async_result(self): + @abc.abstractmethod + def async_result(self) -> IAsyncResult: """Return an instance that conforms to the :class:`~IAsyncResult` interface appropriate for this handler""" - def spawn(self, func, *args, **kwargs): + @abc.abstractmethod + def spawn( + self, func: SpawnedFunc, *args: Any, **kwargs: Any + ) -> Threadlike: """Spawn a function to run asynchronously :param args: args to call the function with. @@ -91,7 +267,8 @@ def spawn(self, func, *args, **kwargs): """ - def dispatch_callback(self, callback): + @abc.abstractmethod + def dispatch_callback(self, callback: Callback) -> None: """Dispatch to the callback object :param callback: A :class:`~kazoo.protocol.states.Callback` @@ -100,7 +277,7 @@ def dispatch_callback(self, callback): """ -class IAsyncResult(object): +class IAsyncResult(abc.ABC): """An Async Result object that can be queried for a value that has been set asynchronously. @@ -123,15 +300,18 @@ class IAsyncResult(object): """ - def ready(self): + @abc.abstractmethod + def ready(self) -> bool: """Return `True` if and only if it holds a value or an exception""" - def successful(self): + @abc.abstractmethod + def successful(self) -> bool: """Return `True` if and only if it is ready and holds a value""" - def set(self, value=None): + @abc.abstractmethod + def set(self, value: Any = None) -> None: """Store the value. Wake up the waiters. :param value: Value to store as the result. @@ -140,7 +320,8 @@ def set(self, value=None): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def set_exception(self, exception): + @abc.abstractmethod + def set_exception(self, exception: Exception) -> None: """Store the exception. Wake up the waiters. :param exception: Exception to raise when fetching the value. @@ -149,7 +330,8 @@ def set_exception(self, exception): up. Sequential calls to :meth:`wait` and :meth:`get` will not block at all.""" - def get(self, block=True, timeout=None): + @abc.abstractmethod + def get(self, block: bool = True, timeout: float | None = None) -> Any: """Return the stored value or raise the exception :param block: Whether this method should block or return @@ -164,13 +346,15 @@ def get(self, block=True, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def get_nowait(self): + @abc.abstractmethod + def get_nowait(self) -> Any: """Return the value or raise the exception without blocking. If nothing is available, raise the Timeout exception class on the associated :class:`IHandler` interface.""" - def wait(self, timeout=None): + @abc.abstractmethod + def wait(self, timeout: float | None = None) -> Any: """Block until the instance is ready. :param timeout: How long to wait for a value when `block` is @@ -182,7 +366,8 @@ def wait(self, timeout=None): :meth:`set_exception` has been called or until the optional timeout occurs.""" - def rawlink(self, callback): + @abc.abstractmethod + def rawlink(self, callback: Callable[[IAsyncResult], Any]) -> None: """Register a callback to call when a value or an exception is set @@ -194,10 +379,17 @@ def rawlink(self, callback): """ - def unlink(self, callback): + @abc.abstractmethod + def unlink(self, callback: Callable[[IAsyncResult], None]) -> None: """Remove the callback set by :meth:`rawlink` :param callback: A callback function to remove. :type callback: func """ + + @property + @abc.abstractmethod + def exception(self) -> Exception | None: + """The exception set by :meth:`set_exception` or `None` if no + exception has been set""" diff --git a/kazoo/protocol/connection.py b/kazoo/protocol/connection.py index 3df7b1626..e4ed1a17d 100644 --- a/kazoo/protocol/connection.py +++ b/kazoo/protocol/connection.py @@ -1,4 +1,7 @@ """Zookeeper Protocol Connection Handler""" + +from __future__ import annotations + from binascii import hexlify from contextlib import contextmanager import copy @@ -8,6 +11,17 @@ import socket import ssl import time +from typing import ( + Callable, + ContextManager, + Iterator, + Literal, + TypeVar, + TYPE_CHECKING, + cast, + overload, +) +from typing_extensions import Buffer from kazoo.exceptions import ( AuthFailedError, @@ -41,9 +55,14 @@ ) from kazoo.retry import ( ForceRetryError, + KazooRetry, RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import Socket, Threadlike + try: import puresasl import puresasl.client @@ -76,7 +95,7 @@ # removed from Python3+ -def buffer(obj, offset=0): +def buffer(obj: Buffer, offset: int = 0) -> memoryview: return memoryview(obj)[offset:] @@ -94,22 +113,32 @@ class RWPinger(object): """ - def __init__(self, hosts, connection_func, socket_handling): + def __init__( + self, + hosts: list[tuple[str, int]], + connection_func: Callable[[tuple[str, int]], Socket], + socket_handling: Callable[[], ContextManager[None]], + ): self.hosts = hosts self.connection = connection_func - self.last_attempt = None + self.last_attempt: float | None = None self.socket_handling = socket_handling - def __iter__(self): + def __iter__(self) -> Iterator[tuple[str, int] | Literal[False] | None]: if not self.last_attempt: self.last_attempt = time.monotonic() delay = 0.5 while True: yield self._next_server(delay) - def _next_server(self, delay): + def _next_server( + self, delay: float + ) -> tuple[str, int] | Literal[False] | None: jitter = random.randint(0, 100) / 100.0 - while time.monotonic() < self.last_attempt + delay + jitter: + while ( + time.monotonic() + < self.last_attempt + delay + jitter # type: ignore[operator] + ): # Skip rw ping checks if its too soon return False for host, port in self.hosts: @@ -128,20 +157,36 @@ def _next_server(self, delay): except ConnectionDropped: return False + # NOTE: This does actually look like it's unreachable but I don't + # want to alter the code any more than necessary for the first + # pass. + # The loop is basically a sleep with jitter that can be # Add some jitter between host pings - while time.monotonic() < self.last_attempt + jitter: + while ( # type: ignore[unreachable] + time.monotonic() < self.last_attempt + jitter + ): return False - delay *= 2 + delay *= 2 # And while not unreachable, this is pointless + return None class RWServerAvailable(Exception): """Thrown if a RW Server becomes available""" +ReturnValue = TypeVar("ReturnValue") + + class ConnectionHandler(object): """Zookeeper connection handler""" - def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): + def __init__( + self, + client: KazooClient, + retry_sleeper: KazooRetry, + logger: logging.Logger | None = None, + sasl_options: dict[str, str] | None = None, + ): self.client = client self.handler = client.handler self.retry_sleeper = retry_sleeper @@ -154,15 +199,17 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): self.connection_stopped.set() self.ping_outstanding = client.handler.event_object() - self._read_sock = None - self._write_sock = None + self._read_sock: Socket | None = None + self._write_sock: Socket | None = None - self._socket = None - self._xid = None - self._rw_server = None - self._ro_mode = False + self._socket: Socket | None = None + self._xid: int | None = None + self._rw_server: tuple[str, int] | None = None + self._ro_mode: Iterator[ + Literal[False] | tuple[str, int] | None + ] | Literal[False] | None = False - self._connection_routine = None + self._connection_routine: Threadlike | None = None self.sasl_options = sasl_options self.sasl_cli = None @@ -170,14 +217,14 @@ def __init__(self, client, retry_sleeper, logger=None, sasl_options=None): # This is instance specific to avoid odd thread bug issues in Python # during shutdown global cleanup @contextmanager - def _socket_error_handling(self): + def _socket_error_handling(self) -> Iterator[None]: try: yield except (socket.error, select.error) as e: err = getattr(e, "strerror", e) raise ConnectionDropped("socket connection error: %s" % (err,)) - def start(self): + def start(self) -> None: """Start the connection up""" if self.connection_closed.is_set(): rw_sockets = self.handler.create_socket_pair() @@ -189,7 +236,7 @@ def start(self): ) self._connection_routine = self.handler.spawn(self.zk_loop) - def stop(self, timeout=None): + def stop(self, timeout: float | None = None) -> bool: """Ensure the writer has stopped, wait to see if it does.""" self.connection_stopped.wait(timeout) if self._connection_routine: @@ -197,7 +244,7 @@ def stop(self, timeout=None): self._connection_routine = None return self.connection_stopped.is_set() - def close(self): + def close(self) -> None: """Release resources held by the connection The connection can be restarted afterwards. @@ -212,7 +259,7 @@ def close(self): if rs is not None: rs.close() - def _server_pinger(self): + def _server_pinger(self) -> RWPinger: """Returns a server pinger iterable, that will ping the next server in the list, and apply a back-off between attempts.""" return RWPinger( @@ -221,16 +268,21 @@ def _server_pinger(self): self._socket_error_handling, ) - def _read_header(self, timeout): + def _read_header( + self, timeout: float | None + ) -> tuple[ReplyHeader, bytes, int]: b = self._read(4, timeout) length = int_struct.unpack(b)[0] b = self._read(length, timeout) header, offset = ReplyHeader.deserialize(b, 0) return header, b, offset - def _read(self, length, timeout): + def _read(self, length: int, timeout: float | None) -> bytes: msgparts = [] remaining = length + # We know that self._socket is not None here because we only call + # this method when we have set up the connection and the read socket. + # But mypy doesn't understand that. with self._socket_error_handling(): while remaining > 0: # Because of SSL framing, a select may not return when using @@ -240,11 +292,13 @@ def _read(self, length, timeout): # data from the underlying socket. if ( hasattr(self._socket, "pending") - and self._socket.pending() > 0 + and self._socket.pending() > 0 # type: ignore[union-attr] ): pass else: - s = self.handler.select([self._socket], [], [], timeout)[0] + s = self.handler.select( + [cast("Socket", self._socket)], [], [], timeout + )[0] if not s: # pragma: nocover # If the read list is empty, we got a timeout. We don't # have to check wlist and xlist as we don't set any @@ -252,7 +306,9 @@ def _read(self, length, timeout): "socket time-out during read" ) try: - chunk = self._socket.recv(remaining) + chunk = self._socket.recv( # type: ignore[union-attr] + remaining + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -267,7 +323,24 @@ def _read(self, length, timeout): remaining -= len(chunk) return b"".join(msgparts) - def _invoke(self, timeout, request, xid=None): + @overload + def _invoke( + self, timeout: float | None, request: Connect + ) -> tuple[Connect, int | None]: + ... + + @overload + def _invoke( + self, timeout: float | None, request: Auth, xid: int + ) -> int | None: + ... + + def _invoke( + self, + timeout: float | None, + request: Auth | Connect, + xid: int | None = None, + ) -> tuple[Connect, int | None] | int | None: """A special writer used during connection establishment only""" self._submit(request, timeout, xid) @@ -296,7 +369,9 @@ def _invoke(self, timeout, request, xid=None): if hasattr(request, "deserialize"): try: - obj, _ = request.deserialize(msg, 0) + # This is a bit of an annoying ignore as I've just done a + # hasattr... + obj, _ = request.deserialize(msg, 0) # type:ignore[union-attr] except Exception: self.logger.exception( "Exception raised during deserialization " @@ -311,7 +386,12 @@ def _invoke(self, timeout, request, xid=None): return zxid - def _submit(self, request, timeout, xid=None): + def _submit( + self, + request: Auth | Connect | Ping | SASL, + timeout: float | None, + xid: int | None = None, + ) -> None: """Submit a request object with a timeout value and optional xid""" b = bytearray() @@ -328,13 +408,17 @@ def _submit(self, request, timeout, xid=None): ) self._write(int_struct.pack(len(b)) + b, timeout) - def _write(self, msg, timeout): + def _write(self, msg: bytes, timeout: float | None) -> None: """Write a raw msg to the socket""" sent = 0 msg_length = len(msg) + # Note: The casts/type: ignore are because mypy can't work out + # self._socket is not None, and I don't want to change any code. with self._socket_error_handling(): while sent < msg_length: - s = self.handler.select([], [self._socket], [], timeout)[1] + s = self.handler.select( + [], [cast("Socket", self._socket)], [], timeout + )[1] if not s: # pragma: nocover # If the write list is empty, we got a timeout. We don't # have to check rlist and xlist as we don't set any @@ -343,7 +427,9 @@ def _write(self, msg, timeout): ) msg_slice = buffer(msg, sent) try: - bytes_sent = self._socket.send(msg_slice) + bytes_sent = self._socket.send( # type:ignore[union-attr] + msg_slice + ) except ssl.SSLError as e: if e.errno in ( ssl.SSL_ERROR_WANT_READ, @@ -356,14 +442,14 @@ def _write(self, msg, timeout): raise ConnectionDropped("socket connection broken") sent += bytes_sent - def _read_watch_event(self, buffer, offset): + def _read_watch_event(self, buffer: bytes, offset: int) -> None: client = self.client watch, offset = Watch.deserialize(buffer, offset) path = watch.path self.logger.debug("Received EVENT: %s", watch) - watchers = [] + watchers: list[WatchFunc] = [] if watch.type in (CREATED_EVENT, CHANGED_EVENT): watchers.extend(client._data_watchers.pop(path, [])) @@ -385,10 +471,15 @@ def _read_watch_event(self, buffer, offset): return # Dump the watchers to the watch thread - for watch in watchers: - client.handler.dispatch_callback(Callback("watch", watch, (ev,))) - - def _read_response(self, header, buffer, offset): + for watch1 in watchers: + client.handler.dispatch_callback(Callback("watch", watch1, (ev,))) + + def _read_response( + self, + header: ReplyHeader, + buffer: bytes, + offset: int, + ) -> object | None: client = self.client request, async_object, xid = client._pending.popleft() if header.zxid and header.zxid > 0: @@ -404,7 +495,11 @@ def _read_response(self, header, buffer, offset): # Determine if its an exists request and a no node error exists_error = ( - header.err == NoNodeError.code and request.type == Exists.type + # NoNodeError does actually have a code. It's added by a decorator, + # which could possibly be better done via inheritance but this is + # less invasive to the existing code. + header.err == NoNodeError.code # type: ignore[attr-defined] + and request.type == Exists.type ) # Set the exception if its not an exists error @@ -430,7 +525,7 @@ def _read_response(self, header, buffer, offset): request, ) async_object.set_exception(exc) - return + return None self.logger.debug( "Received response(xid=%s): %r", xid, response ) @@ -452,8 +547,9 @@ def _read_response(self, header, buffer, offset): if isinstance(request, Close): self.logger.log(BLATHER, "Read close response") return CLOSE_RESPONSE + return None - def _read_socket(self, read_timeout): + def _read_socket(self, read_timeout: float) -> object | None: """Called when there's something to read on the socket""" client = self.client @@ -476,8 +572,13 @@ def _read_socket(self, read_timeout): self.logger.log(BLATHER, "Reading for header %r", header) return self._read_response(header, buffer, offset) + return None - def _send_request(self, read_timeout, connect_timeout): + def _send_request( + self, + read_timeout: float, + connect_timeout: float, + ) -> None: """Called when we have something to send out on the socket""" client = self.client try: @@ -489,7 +590,11 @@ def _send_request(self, read_timeout, connect_timeout): try: # Clear possible inconsistence (no request in the queue # but have data in the read socket), which causes cpu to spin. - self._read_sock.recv(1) + # + # We know _read_sock is not None because we only call this + # method when we have set up the connection and the read + # socket, but mypy doesn't understand that. + self._read_sock.recv(1) # type: ignore[union-attr] except OSError: pass return @@ -505,15 +610,19 @@ def _send_request(self, read_timeout, connect_timeout): if request.type == Auth.type: xid = AUTH_XID else: - self._xid = (self._xid % 2147483647) + 1 + # We must have initialised the xid counter by now + # Might want to consider initialising it to 0 instead of none? + self._xid = (self._xid % 2147483647) + 1 # type: ignore[operator] xid = self._xid self._submit(request, connect_timeout, xid) client._queue.popleft() - self._read_sock.recv(1) + # _read_sock should never be None here as we only call this method + # when we have set up the connection and the read socket. + self._read_sock.recv(1) # type: ignore[union-attr] client._pending.append((request, async_object, xid)) - def _send_ping(self, connect_timeout): + def _send_ping(self, connect_timeout: float) -> None: self.ping_outstanding.set() self._submit(PingInstance, connect_timeout, PING_XID) @@ -524,7 +633,7 @@ def _send_ping(self, connect_timeout): self._rw_server = result raise RWServerAvailable() - def zk_loop(self): + def zk_loop(self) -> None: """Main Zookeeper handling loop""" self.logger.log(BLATHER, "ZK loop started") @@ -546,16 +655,28 @@ def zk_loop(self): self.client._session_callback(KeeperState.CLOSED) self.logger.log(BLATHER, "Connection stopped") - def _expand_client_hosts(self): + def _expand_client_hosts(self) -> list[tuple[str, str, int]]: # Expand the entire list in advance so we can randomize it if needed - host_ports = [] + host_ports: list[tuple[str, str, int]] = [] for host, port in self.client.hosts: try: host = host.strip() for rhost in socket.getaddrinfo( host, port, 0, 0, socket.IPPROTO_TCP ): - host_ports.append((host, rhost[4][0], rhost[4][1])) + # FIXME These casts seem to be unnecessary on later + # versions of mypy/python + host_ports.append( + ( + host, + cast( # type: ignore[redundant-cast] + "str", rhost[4][0] + ), + cast( # type: ignore[redundant-cast] + "int", rhost[4][1] + ), + ) + ) except socket.gaierror as e: # Skip hosts that don't resolve self.logger.warning("Cannot resolve %s: %s", host, e) @@ -564,7 +685,7 @@ def _expand_client_hosts(self): random.shuffle(host_ports) return host_ports - def _connect_loop(self, retry): + def _connect_loop(self, retry: KazooRetry) -> object: # Iterate through the hosts a full cycle before starting over status = None host_ports = self._expand_client_hosts() @@ -586,7 +707,13 @@ def _connect_loop(self, retry): else: raise ForceRetryError("Reconnecting") - def _connect_attempt(self, host, hostip, port, retry): + def _connect_attempt( + self, + host: str, + hostip: str, + port: int, + retry: KazooRetry, + ) -> object: client = self.client KazooTimeoutError = self.handler.timeout_exception @@ -606,6 +733,9 @@ def _connect_attempt(self, host, hostip, port, retry): try: self._xid = 0 read_timeout, connect_timeout = self._connect(host, hostip, port) + # I think the above implies self._socket can't be none, and + # self._read_sock is set up in start but mypy can't tell that. + # Hence the casting. read_timeout = read_timeout / 1000.0 connect_timeout = connect_timeout / 1000.0 retry.reset() @@ -619,7 +749,14 @@ def _connect_attempt(self, host, hostip, port, retry): # Ensure our timeout is positive timeout = max([deadline - time.monotonic(), jitter_time]) s = self.handler.select( - [self._socket, self._read_sock], [], [], timeout + [ + # FIXME we should know these aren't None + cast("Socket", self._socket), + cast("Socket", self._read_sock), + ], + [], + [], + timeout, )[0] if not s: @@ -629,14 +766,14 @@ def _connect_attempt(self, host, hostip, port, retry): "outstanding heartbeat ping not received" ) else: - if self._socket in s: + if cast("Socket", self._socket) in s: response = self._read_socket(read_timeout) if response == CLOSE_RESPONSE: break # Check if any requests need sending before proceeding # to process more responses. Otherwise the responses # may choke out the requests. See PR#633. - if self._read_sock in s: + if cast("Socket", self._read_sock) in s: self._send_request(read_timeout, connect_timeout) # Requests act as implicit pings. last_send = time.monotonic() @@ -674,9 +811,18 @@ def _connect_attempt(self, host, hostip, port, retry): raise finally: if self._socket is not None: - self._socket.close() - - def _connect(self, host, hostip, port): + # I think this is a bug in mypy, as the socket does get set up + # in self._connect, but it doesn't seem to be able to track + # that. + self._socket.close() # type: ignore[unreachable] + return None + + def _connect( + self, + host: str, + hostip: str, + port: int, + ) -> tuple[float, float]: client = self.client self.logger.info( "Connecting to %s(%s):%s, use_ssl: %r", @@ -707,7 +853,7 @@ def _connect(self, host, hostip, port): check_hostname=self.client.check_hostname, ) - self._socket.setblocking(0) + self._socket.setblocking(0) # type: ignore[arg-type] connect = Connect( 0, @@ -771,23 +917,36 @@ def _connect(self, host, hostip, port): return read_timeout, connect_timeout - def _authenticate_with_sasl(self, host, timeout): + def _authenticate_with_sasl(self, host: str, timeout: float) -> None: """Establish a SASL authenticated connection to the server.""" if not PURESASL_AVAILABLE: raise SASLException("Missing SASL support") - if "service" not in self.sasl_options: - self.sasl_options["service"] = "zookeeper" + # Although this can only be called if sasl_options is not None, we + # really should just have make self.sasl_options into an empty dict + # in the constructor. However, I want to avoid code changes in as + # much as possible. + if "service" not in self.sasl_options: # type: ignore[operator] + self.sasl_options["service"] = "zookeeper" # type: ignore[index] # NOTE: Zookeeper hardcoded the domain for Digest authentication # instead of using the hostname. See # zookeeper/util/SecurityUtils.java#L74 and Server/Client # initializations. - if self.sasl_options["mechanism"] == "DIGEST-MD5": + if ( + self.sasl_options["mechanism"] # type: ignore[index] + == "DIGEST-MD5" + ): host = "zk-sasl-md5" - sasl_cli = self.client.sasl_cli = puresasl.client.SASLClient( - host=host, **self.sasl_options + # I don't think the client.sasl_cli attribute is actually used + # anywhere else, so not sure why we need to set it on the client, + # but again, I want to avoid code changes as much as possible. + sasl_cli = ( + self.client.sasl_cli # type: ignore[attr-defined] + ) = puresasl.client.SASLClient( # type: ignore[no-untyped-call] + host=host, + **self.sasl_options, # type: ignore[arg-type] ) # Initialize the process with an empty challenge token diff --git a/kazoo/protocol/paths.py b/kazoo/protocol/paths.py index b8bf66507..7c47ce8a0 100644 --- a/kazoo/protocol/paths.py +++ b/kazoo/protocol/paths.py @@ -1,4 +1,4 @@ -def normpath(path, trailing=False): +def normpath(path: str, trailing: bool = False) -> str: """Normalize path, eliminating double slashes, etc.""" comps = path.split("/") new_comps = [] @@ -16,7 +16,7 @@ def normpath(path, trailing=False): return new_path -def join(a, *p): +def join(a: str, *p: str) -> str: """Join two or more pathname components, inserting '/' as needed. If any component is an absolute path, all previous path components @@ -34,23 +34,23 @@ def join(a, *p): return path -def isabs(s): +def isabs(s: str) -> bool: """Test whether a path is absolute""" return s.startswith("/") -def basename(p): +def basename(p: str) -> str: """Returns the final component of a pathname""" i = p.rfind("/") + 1 return p[i:] -def _prefix_root(root, path, trailing=False): +def _prefix_root(root: str, path: str, trailing: bool = False) -> str: """Prepend a root to a path.""" return normpath( join(_norm_root(root), path.lstrip("/")), trailing=trailing ) -def _norm_root(root): +def _norm_root(root: str) -> str: return normpath(join("/", root)) diff --git a/kazoo/protocol/serialization.py b/kazoo/protocol/serialization.py index 40e6360c2..29a55f35d 100644 --- a/kazoo/protocol/serialization.py +++ b/kazoo/protocol/serialization.py @@ -1,12 +1,23 @@ -"""Zookeeper Serializers, Deserializers, and NamedTuple objects""" -from collections import namedtuple +"""Zookeeper Serializers, Deserializers, and namedtuple objects + +Note: On python3.8, you can't do classvars with NamedTuple. + +FIXME As soon as we get off python3.8 we should change the namedtuple objects +to NamedTuple, as it should get better typechecking. +""" +from __future__ import annotations + import struct +from collections import namedtuple +from typing import ClassVar, Sequence, Union, TYPE_CHECKING -from kazoo.exceptions import EXCEPTIONS +from kazoo.exceptions import EXCEPTIONS, ZookeeperError from kazoo.protocol.states import ZnodeStat from kazoo.security import ACL from kazoo.security import Id +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc # Struct objects with formats compiled bool_struct = struct.Struct("B") @@ -21,20 +32,24 @@ stat_struct = struct.Struct("!qqqqiiiqiiq") -def read_string(buffer, offset): +def read_string(buffer: bytes, offset: int) -> tuple[str, int]: """Reads an int specified buffer into a string and returns the string and the new offset in the buffer""" length = int_struct.unpack_from(buffer, offset)[0] offset += int_struct.size if length < 0: - return None, offset + # A note: write_str sends a length of -1 to indicate a value of None + # was passed. Not entirely sure where this happens because none of the + # callers of read_string seem to expect a None value. + # Should be ignoring return-value but hound cli... + return None, offset # type: ignore else: index = offset offset += length return buffer[index : index + length].decode("utf-8"), offset -def read_acl(bytes, offset): +def read_acl(bytes: bytes, offset: int) -> tuple[ACL, int]: perms = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size scheme, offset = read_string(bytes, offset) @@ -42,7 +57,7 @@ def read_acl(bytes, offset): return ACL(perms, Id(scheme, id)), offset -def write_string(bytes): +def write_string(bytes: str | None) -> bytes: if not bytes: return int_struct.pack(-1) else: @@ -50,14 +65,14 @@ def write_string(bytes): return int_struct.pack(len(utf8_str)) + utf8_str -def write_buffer(bytes): +def write_buffer(bytes: bytes | None) -> bytes: if bytes is None: return int_struct.pack(-1) else: return int_struct.pack(len(bytes)) + bytes -def read_buffer(bytes, offset): +def read_buffer(bytes: bytes, offset: int) -> tuple[bytes | None, int]: length = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if length < 0: @@ -69,10 +84,10 @@ def read_buffer(bytes, offset): class Close(namedtuple("Close", "")): - type = -11 + type: ClassVar[int] = -11 @classmethod - def serialize(cls): + def serialize(self) -> bytes: return b"" @@ -80,10 +95,10 @@ def serialize(cls): class Ping(namedtuple("Ping", "")): - type = 11 + type: ClassVar[int] = 11 @classmethod - def serialize(cls): + def serialize(cls) -> bytes: return b"" @@ -97,9 +112,16 @@ class Connect( " time_out session_id passwd read_only", ) ): - type = None + protocol_version: int + last_zxid_seen: int + time_out: int + session_id: int + passwd: bytes + read_only: bool - def serialize(self): + type: int | None = None # Note: Not a classvar + + def serialize(self) -> bytearray: b = bytearray() b.extend( int_long_int_long_struct.pack( @@ -114,7 +136,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Connect, int]: proto_version, timeout, session_id = int_int_long_struct.unpack_from( bytes, offset ) @@ -133,9 +155,14 @@ def deserialize(cls, bytes, offset): class Create(namedtuple("Create", "path data acl flags")): - type = 1 + path: str + data: bytes | None + acl: Sequence[ACL] + flags: int + + type: ClassVar[int] = 1 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -150,59 +177,74 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> str: return read_string(bytes, offset)[0] class Delete(namedtuple("Delete", "path version")): - type = 2 + path: str + version: int - def serialize(self): + type: ClassVar[int] = 2 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b @classmethod - def deserialize(self, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> bool: return True class Exists(namedtuple("Exists", "path watcher")): - type = 3 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 3 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat | None: + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return stat if stat.czxid != -1 else None class GetData(namedtuple("GetData", "path watcher")): - type = 4 + path: str + watcher: WatchFunc | None - def serialize(self): + type: ClassVar[int] = 4 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class SetData(namedtuple("SetData", "path data version")): - type = 5 + path: str + data: bytes | None + version: int + + type: ClassVar[int] = 5 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -210,18 +252,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetACL(namedtuple("GetACL", "path")): - type = 6 + path: str - def serialize(self): + type: ClassVar[int] = 6 + + def serialize(self) -> bytearray: return bytearray(write_string(self.path)) @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[list[ACL], ZnodeStat] | list[ACL]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -231,14 +277,18 @@ def deserialize(cls, bytes, offset): for c in range(count): acl, offset = read_acl(bytes, offset) acls.append(acl) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return acls, stat class SetACL(namedtuple("SetACL", "path acls version")): - type = 7 + path: str + acls: Sequence[ACL] + version: int + + type: ClassVar[int] = 7 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(len(self.acls))) @@ -252,21 +302,24 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): - return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + def deserialize(cls, bytes: bytes, offset: int) -> ZnodeStat: + return ZnodeStat(*stat_struct.unpack_from(bytes, offset)) class GetChildren(namedtuple("GetChildren", "path watcher")): - type = 8 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 8 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover @@ -280,54 +333,71 @@ def deserialize(cls, bytes, offset): class Sync(namedtuple("Sync", "path")): - type = 9 + path: str - def serialize(self): + type: ClassVar[int] = 9 + + def serialize(self) -> bytes: return write_string(self.path) @classmethod - def deserialize(cls, buffer, offset): + def deserialize(cls, buffer: bytes, offset: int) -> str: return read_string(buffer, offset)[0] class GetChildren2(namedtuple("GetChildren2", "path watcher")): - type = 12 + path: str + watcher: WatchFunc | None + + type: ClassVar[int] = 12 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend([1 if self.watcher else 0]) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[list[str], ZnodeStat] | list[str]: count = int_struct.unpack_from(bytes, offset)[0] offset += int_struct.size if count == -1: # pragma: nocover return [] - children = [] + children: list[str] = [] for c in range(count): child, offset = read_string(bytes, offset) children.append(child) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return children, stat class CheckVersion(namedtuple("CheckVersion", "path version")): - type = 13 + path: str + version: int + + type: ClassVar[int] = 13 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b +# FIXME Transaction class should move after Create2 +Transaction_Types = Union[Create, "Create2", Delete, SetData, CheckVersion] +Transaction_Response = Union[str, bool, ZnodeStat, ZookeeperError, None] + + class Transaction(namedtuple("Transaction", "operations")): - type = 14 + operations: list[Transaction_Types] + + type: ClassVar[int] = 14 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() for op in self.operations: b.extend( @@ -336,19 +406,19 @@ def serialize(self): return b + multiheader_struct.pack(-1, True, -1) @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> list[Transaction_Response]: header = MultiHeader(None, False, None) - results = [] - response = None + results: list[Transaction_Response] = [] + response: Transaction_Response = None while not header.done: if header.type == Create.type: response, offset = read_string(bytes, offset) elif header.type == Delete.type: response = True elif header.type == SetData.type: - response = ZnodeStat._make( - stat_struct.unpack_from(bytes, offset) - ) + response = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) offset += stat_struct.size elif header.type == CheckVersion.type: response = True @@ -362,8 +432,10 @@ def deserialize(cls, bytes, offset): return results @staticmethod - def unchroot(client, response): - resp = [] + def unchroot( + client: KazooClient, response: list[Transaction_Response] + ) -> list[Transaction_Response]: + resp: list[Transaction_Response] = [] for result in response: if isinstance(result, str): resp.append(client.unchroot(result)) @@ -373,9 +445,14 @@ def unchroot(client, response): class Create2(namedtuple("Create2", "path data acl flags")): - type = 15 + path: str + data: bytes | None + acl: Sequence[ACL] + flags: int - def serialize(self): + type: ClassVar[int] = 15 + + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(write_buffer(self.data)) @@ -390,18 +467,23 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[str, ZnodeStat]: path, offset = read_string(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return path, stat class Reconfig( namedtuple("Reconfig", "joining leaving new_members config_id") ): - type = 16 + joining: str | None + leaving: str | None + new_members: str | None + config_id: int + + type: ClassVar[int] = 16 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.joining)) b.extend(write_string(self.leaving)) @@ -410,16 +492,22 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, ZnodeStat]: data, offset = read_buffer(bytes, offset) - stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + stat = ZnodeStat(*stat_struct.unpack_from(bytes, offset)) return data, stat class Auth(namedtuple("Auth", "auth_type scheme auth")): - type = 100 + auth_type: int + scheme: str + auth: str - def serialize(self): + type: ClassVar[int] = 100 + + def serialize(self) -> bytes: return ( int_struct.pack(self.auth_type) + write_string(self.scheme) @@ -428,22 +516,30 @@ def serialize(self): class SASL(namedtuple("SASL", "challenge")): - type = 102 + challenge: bytes | None + + type: ClassVar[int] = 102 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_buffer(self.challenge)) return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize( + cls, bytes: bytes, offset: int + ) -> tuple[bytes | None, int]: challenge, offset = read_buffer(bytes, offset) return challenge, offset class Watch(namedtuple("Watch", "type state path")): + type: int + state: int + path: str + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[Watch, int]: """Given bytes and the current bytes offset, return the type, state, path, and new offset""" type, state = int_int_struct.unpack_from(bytes, offset) @@ -453,19 +549,27 @@ def deserialize(cls, bytes, offset): class ReplyHeader(namedtuple("ReplyHeader", "xid, zxid, err")): + xid: int + zxid: int + err: int + @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[ReplyHeader, int]: """Given bytes and the current bytes offset, return a :class:`ReplyHeader` instance and the new offset""" new_offset = offset + reply_header_struct.size return ( - cls._make(reply_header_struct.unpack_from(bytes, offset)), + cls(*reply_header_struct.unpack_from(bytes, offset)), new_offset, ) -class MultiHeader(namedtuple("MultiHeader", "type done err")): - def serialize(self): +class MultiHeader(namedtuple("MultiHeader", "type, done, err")): + type: int | None + done: bool + err: int | None + + def serialize(self) -> bytearray: b = bytearray() b.extend(int_struct.pack(self.type)) b.extend([1 if self.done else 0]) @@ -473,7 +577,7 @@ def serialize(self): return b @classmethod - def deserialize(cls, bytes, offset): + def deserialize(cls, bytes: bytes, offset: int) -> tuple[MultiHeader, int]: t, done, err = multiheader_struct.unpack_from(bytes, offset) offset += multiheader_struct.size return cls(t, done == 1, err), offset diff --git a/kazoo/protocol/states.py b/kazoo/protocol/states.py index 480a586e8..e2dc03fd5 100644 --- a/kazoo/protocol/states.py +++ b/kazoo/protocol/states.py @@ -1,8 +1,13 @@ """Kazoo State and Event objects""" -from collections import namedtuple +from __future__ import annotations -class KazooState(object): +from enum import Enum +from typing import Any, Callable, Iterable, NamedTuple + + +# This is a (str, Enum) for backwards compatibility. +class KazooState(str, Enum): """High level connection state values States inspired by Netflix Curator. @@ -33,7 +38,8 @@ class KazooState(object): LOST = "LOST" -class KeeperState(object): +# This is a (str, Enum) for backwards compatibility. +class KeeperState(str, Enum): """Zookeeper State Represents the Zookeeper state. Watch functions will receive a @@ -70,7 +76,8 @@ class KeeperState(object): EXPIRED_SESSION = "EXPIRED_SESSION" -class EventType(object): +# This is a (str, Enum) for backwards compatibility. +class EventType(str, Enum): """Zookeeper Event Represents a Zookeeper event. Events trigger watch functions which @@ -117,7 +124,7 @@ class EventType(object): } -class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): +class WatchedEvent(NamedTuple): """A change on ZooKeeper that a Watcher is able to respond to. The :class:`WatchedEvent` includes exactly what happened, the @@ -140,8 +147,12 @@ class WatchedEvent(namedtuple("WatchedEvent", ("type", "state", "path"))): """ + type: EventType + state: KeeperState + path: str | None + -class Callback(namedtuple("Callback", ("type", "func", "args"))): +class Callback(NamedTuple): """A callback that is handed to a handler for dispatch :param type: Type of the callback, currently is only 'watch' @@ -150,15 +161,12 @@ class Callback(namedtuple("Callback", ("type", "func", "args"))): """ + type: str + func: Callable[..., Any] + args: Iterable[Any] -class ZnodeStat( - namedtuple( - "ZnodeStat", - "czxid mzxid ctime mtime version" - " cversion aversion ephemeralOwner dataLength" - " numChildren pzxid", - ) -): + +class ZnodeStat(NamedTuple): """A ZnodeStat structure with convenience properties When getting the value of a znode from Zookeeper, the properties for @@ -216,38 +224,50 @@ class ZnodeStat( """ + czxid: int + mzxid: int + ctime: int + mtime: int + version: int + cversion: int + aversion: int + ephemeralOwner: int + dataLength: int + numChildren: int + pzxid: int + @property - def acl_version(self): + def acl_version(self) -> int: return self.aversion @property - def children_version(self): + def children_version(self) -> int: return self.cversion @property - def created(self): + def created(self) -> float: return self.ctime / 1000.0 @property - def last_modified(self): + def last_modified(self) -> float: return self.mtime / 1000.0 @property - def owner_session_id(self): + def owner_session_id(self) -> int | None: return self.ephemeralOwner or None @property - def creation_transaction_id(self): + def creation_transaction_id(self) -> int: return self.czxid @property - def last_modified_transaction_id(self): + def last_modified_transaction_id(self) -> int: return self.mzxid @property - def data_length(self): + def data_length(self) -> int: return self.dataLength @property - def children_count(self): + def children_count(self) -> int: return self.numChildren diff --git a/kazoo/py.typed b/kazoo/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/kazoo/recipe/barrier.py b/kazoo/recipe/barrier.py index 683e807b0..26ffc9357 100644 --- a/kazoo/recipe/barrier.py +++ b/kazoo/recipe/barrier.py @@ -4,12 +4,19 @@ :Status: Unknown """ + +from __future__ import annotations + import os import socket import uuid +from typing import Literal, TYPE_CHECKING from kazoo.exceptions import KazooException, NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent + +if TYPE_CHECKING: + from kazoo.client import KazooClient class Barrier(object): @@ -27,7 +34,7 @@ class Barrier(object): """ - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """Create a Kazoo Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -37,11 +44,11 @@ def __init__(self, client, path): self.client = client self.path = path - def create(self): + def create(self) -> None: """Establish the barrier if it doesn't exist already""" self.client.retry(self.client.ensure_path, self.path) - def remove(self): + def remove(self) -> bool: """Remove the barrier :returns: Whether the barrier actually needed to be removed. @@ -54,7 +61,7 @@ def remove(self): except NoNodeError: return False - def wait(self, timeout=None): + def wait(self, timeout: float | None = None) -> bool: """Wait on the barrier to be cleared :returns: True if the barrier has been cleared, otherwise @@ -64,7 +71,7 @@ def wait(self, timeout=None): """ cleared = self.client.handler.event_object() - def wait_for_clear(event): + def wait_for_clear(event: WatchedEvent) -> None: if event.type == EventType.DELETED: cleared.set() @@ -93,7 +100,13 @@ class DoubleBarrier(object): """ - def __init__(self, client, path, num_clients, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + num_clients: int, + identifier: str | None = None, + ): """Create a Double Barrier :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -118,7 +131,7 @@ def __init__(self, client, path, num_clients, identifier=None): self.node_name = uuid.uuid4().hex self.create_path = self.path + "/" + self.node_name - def enter(self): + def enter(self) -> None: """Enter the barrier, blocks until all nodes have entered""" try: self.client.retry(self._inner_enter) @@ -128,7 +141,7 @@ def enter(self): self._best_effort_cleanup() self.participating = False - def _inner_enter(self): + def _inner_enter(self) -> Literal[True]: # make sure our barrier parent node exists if not self.assured_path: self.client.ensure_path(self.path) @@ -145,7 +158,7 @@ def _inner_enter(self): except NodeExistsError: pass - def created(event): + def created(event: WatchedEvent) -> None: if event.type == EventType.CREATED: ready.set() @@ -159,7 +172,7 @@ def created(event): self.client.ensure_path(self.path + "/ready") return True - def leave(self): + def leave(self) -> None: """Leave the barrier, blocks until all nodes have left""" try: self.client.retry(self._inner_leave) @@ -168,7 +181,7 @@ def leave(self): self._best_effort_cleanup() self.participating = False - def _inner_leave(self): + def _inner_leave(self) -> bool: # Delete the ready node if its around try: self.client.delete(self.path + "/ready") @@ -188,7 +201,7 @@ def _inner_leave(self): ready = self.client.handler.event_object() - def deleted(event): + def deleted(event: WatchedEvent) -> None: if event.type == EventType.DELETED: ready.set() @@ -214,7 +227,7 @@ def deleted(event): # Wait for the lowest to be deleted ready.wait() - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.retry(self.client.delete, self.create_path) except NoNodeError: diff --git a/kazoo/recipe/cache.py b/kazoo/recipe/cache.py index 0a22a6c7e..cee96c82f 100644 --- a/kazoo/recipe/cache.py +++ b/kazoo/recipe/cache.py @@ -10,20 +10,42 @@ See also: http://curator.apache.org/curator-recipes/tree-cache.html """ + +from __future__ import annotations from __future__ import absolute_import import contextlib import functools import logging import operator +from typing import ( + Any, + Callable, + Generator, + Protocol, + TypeVar, + Tuple, + TYPE_CHECKING, + Union, + overload, +) + from kazoo.exceptions import NoNodeError, KazooException from kazoo.protocol.paths import _prefix_root, join as kazoo_join -from kazoo.protocol.states import KazooState, EventType +from kazoo.protocol.states import KazooState, EventType, ZnodeStat + +if TYPE_CHECKING: + from kazoo.client import KazooClient, WatchFunc + from kazoo.interfaces import IAsyncResult, Threadlike + from kazoo.protocol.states import WatchedEvent logger = logging.getLogger(__name__) +ReturnValue = TypeVar("ReturnValue") + + class TreeCache(object): """The cache of a ZooKeeper subtree. @@ -37,18 +59,18 @@ class TreeCache(object): _STOP = object() - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): self._client = client self._root = TreeNode.make_root(self, path) self._state = self.STATE_LATENT self._outstanding_ops = 0 self._is_initialized = False - self._error_listeners = [] - self._event_listeners = [] + self._error_listeners: list[Callable[[Exception], None]] = [] + self._event_listeners: list[Callable[[TreeEvent], None]] = [] self._task_queue = client.handler.queue_impl() - self._task_thread = None + self._task_thread: Threadlike | None = None - def start(self): + def start(self) -> None: """Starts the cache. The cache is not started automatically. You must call this method. @@ -85,7 +107,7 @@ def start(self): # without lock. self._in_background(self._root.on_created) - def close(self): + def close(self) -> None: """Closes the cache. A closed cache was detached from ZooKeeper's changes. And all nodes @@ -109,7 +131,9 @@ def close(self): # ZooKeeper actually. self._root.on_deleted() - def listen(self, listener): + def listen( + self, listener: Callable[[TreeEvent], None] + ) -> Callable[[TreeEvent], None]: """Registers a function to listen the cache events. The cache events are changes of local data. They are delivered from @@ -124,7 +148,9 @@ def listen(self, listener): self._event_listeners.append(listener) return listener - def listen_fault(self, listener): + def listen_fault( + self, listener: Callable[[Exception], None] + ) -> Callable[[Exception], None]: """Registers a function to listen the exceptions. It is possible to meet some exceptions during the cache running. You @@ -138,7 +164,9 @@ def listen_fault(self, listener): self._error_listeners.append(listener) return listener - def get_data(self, path, default=None): + def get_data( + self, path: str, default: NodeData | None = None + ) -> NodeData | None: """Gets data of a node from cache. :param path: The absolute path string. @@ -150,7 +178,9 @@ def get_data(self, path, default=None): node = self._find_node(path) return default if node is None else node._data - def get_children(self, path, default=None): + def get_children( + self, path: str, default: frozenset[str] | None = None + ) -> frozenset[str] | None: """Gets node children list from in-memory snapshot. :param path: The absolute path string. @@ -158,11 +188,14 @@ def get_children(self, path, default=None): does not exist. :raises ValueError: If the path is outside of this subtree. :returns: The :class:`frozenset` which including children names. + + # FIXME the default return value should be an empty frozenset, + # returning None is confusing. """ node = self._find_node(path) return default if node is None else frozenset(node._children) - def _find_node(self, path): + def _find_node(self, path: str) -> TreeNode | None: if not path.startswith(self._root._path): raise ValueError("outside of tree") striped_path = path[len(self._root._path) :].strip("/") @@ -170,25 +203,49 @@ def _find_node(self, path): current_node = self._root for node_name in splited_path: if node_name not in current_node._children: - return + return None current_node = current_node._children[node_name] return current_node - def _publish_event(self, event_type, event_data=None): + def _publish_event( + self, event_type: int, event_data: NodeData | None = None + ) -> None: event = TreeEvent.make(event_type, event_data) if self._state != self.STATE_CLOSED: logger.debug("public event: %r", event) self._in_background(self._do_publish_event, event) - def _do_publish_event(self, event): + def _do_publish_event(self, event: TreeEvent) -> None: for listener in self._event_listeners: with handle_exception(self._error_listeners): listener(event) - def _in_background(self, func, *args, **kwargs): + @overload + def _in_background( + self, func: Callable[[TreeEvent], None], event: TreeEvent + ) -> None: + ... + + @overload + def _in_background(self, func: Callable[[], None]) -> None: + ... + + @overload + def _in_background( + self, + func: Callable[[str, str, IAsyncResult], None], + method_name: str, + path: str, + result: IAsyncResult, + ) -> None: + ... + + def _in_background( # type: ignore[misc] + self, func: Callable[..., Any], *args: Any, **kwargs: Any + ) -> None: self._task_queue.put((func, args, kwargs)) - def _do_background(self): + def _do_background(self) -> None: while True: with handle_exception(self._error_listeners): cb = self._task_queue.get() @@ -200,7 +257,7 @@ def _do_background(self): # release before possible idle del cb, func, args, kwargs - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.SUSPENDED: self._publish_event(TreeEvent.CONNECTION_SUSPENDED) elif state == KazooState.CONNECTED: @@ -212,6 +269,11 @@ def _session_watcher(self, state): self._publish_event(TreeEvent.CONNECTION_LOST) +class AsyncWatcher(Protocol): + def __call__(self, path: str, watch: WatchFunc | None) -> IAsyncResult: + ... + + class TreeNode(object): """The tree node record. @@ -234,28 +296,28 @@ class TreeNode(object): STATE_LIVE = 1 STATE_DEAD = 2 - def __init__(self, tree, path, parent): + def __init__(self, tree: TreeCache, path: str, parent: TreeNode | None): self._tree = tree self._path = path self._parent = parent - self._depth = parent._depth + 1 if parent else 0 - self._children = {} + self._depth: int = parent._depth + 1 if parent is not None else 0 + self._children: dict[str, TreeNode] = {} self._state = self.STATE_PENDING - self._data = None + self._data: NodeData | None = None @classmethod - def make_root(cls, tree, path): + def make_root(cls, tree: TreeCache, path: str) -> TreeNode: return cls(tree, path, None) - def on_reconnected(self): + def on_reconnected(self) -> None: self._refresh() for child in self._children.values(): child.on_reconnected() - def on_created(self): + def on_created(self) -> None: self._refresh() - def on_deleted(self): + def on_deleted(self) -> None: old_children, self._children = self._children, {} old_data, self._data = self._data, None @@ -278,37 +340,43 @@ def on_deleted(self): del self._parent._children[child] self._reset_watchers() - def _publish_event(self, *args, **kwargs): - return self._tree._publish_event(*args, **kwargs) + def _publish_event( + self, event_type: int, event_data: NodeData | None = None + ) -> None: + return self._tree._publish_event(event_type, event_data) - def _reset_watchers(self): + def _reset_watchers(self) -> None: client = self._tree._client for _watchers in (client._data_watchers, client._child_watchers): _path = _prefix_root(client.chroot, self._path) _watcher = _watchers.get(_path, set()) _watcher.discard(self._process_watch) - def _refresh(self): + def _refresh(self) -> None: self._refresh_data() self._refresh_children() - def _refresh_data(self): + def _refresh_data(self) -> None: self._call_client("get", self._path) - def _refresh_children(self): + def _refresh_children(self) -> None: # TODO max-depth checking support self._call_client("get_children", self._path) - def _call_client(self, method_name, path): + def _call_client(self, method_name: str, path: str) -> None: assert method_name in ("get", "get_children", "exists") self._tree._outstanding_ops += 1 callback = functools.partial( self._tree._in_background, self._process_result, method_name, path ) - method = getattr(self._tree._client, method_name + "_async") + # The typing for this is really bad but the type checker can + # understand it with a few hacks + method: AsyncWatcher = getattr( + self._tree._client, method_name + "_async" + ) method(path, watch=self._process_watch).rawlink(callback) - def _process_watch(self, watched_event): + def _process_watch(self, watched_event: WatchedEvent) -> None: logger.debug("process_watch: %r", watched_event) with handle_exception(self._tree._error_listeners): if watched_event.type == EventType.CREATED: @@ -321,7 +389,9 @@ def _process_watch(self, watched_event): elif watched_event.type == EventType.CHILD: self._refresh_children() - def _process_result(self, method_name, path, result): + def _process_result( + self, method_name: str, path: str, result: IAsyncResult + ) -> None: logger.debug("process_result: %s %s", method_name, path) if method_name == "exists": assert self._parent is None, "unexpected EXISTS on non-root" @@ -332,7 +402,7 @@ def _process_result(self, method_name, path, result): self.on_created() elif method_name == "get_children": if result.successful(): - children = result.get() + children: list[str] = result.get() for child in sorted(children): full_path = kazoo_join(path, child) if child not in self._children: @@ -367,9 +437,12 @@ def _process_result(self, method_name, path, result): self._publish_event(TreeEvent.INITIALIZED) -class TreeEvent(tuple): +# FIXME these Tuple-based classes would look a lot better using NamedTuple +# though the event_type in TreeEvent needs sorting. +class TreeEvent(Tuple[int, Union["NodeData", None]]): """The immutable event tuple of cache.""" + # FIXME These should be an enum. NODE_ADDED = 0 NODE_UPDATED = 1 NODE_REMOVED = 2 @@ -385,7 +458,9 @@ class TreeEvent(tuple): event_data = property(operator.itemgetter(1)) @classmethod - def make(cls, event_type, event_data): + def make( + cls, event_type: int, event_data: NodeData | None = None + ) -> TreeEvent: """Creates a new TreeEvent tuple. :returns: A :class:`~kazoo.recipe.cache.TreeEvent` instance. @@ -402,7 +477,7 @@ def make(cls, event_type, event_data): return cls((event_type, event_data)) -class NodeData(tuple): +class NodeData(Tuple[str, bytes, ZnodeStat]): """The immutable node data tuple of cache.""" #: The absolute path string of current node. @@ -415,7 +490,7 @@ class NodeData(tuple): stat = property(operator.itemgetter(2)) @classmethod - def make(cls, path, data, stat): + def make(cls, path: str, data: bytes, stat: ZnodeStat) -> NodeData: """Creates a new NodeData tuple. :returns: A :class:`~kazoo.recipe.cache.NodeData` instance. @@ -424,7 +499,9 @@ def make(cls, path, data, stat): @contextlib.contextmanager -def handle_exception(listeners): +def handle_exception( + listeners: list[Callable[[Exception], None]], +) -> Generator[None, None, None]: try: yield except Exception as e: diff --git a/kazoo/recipe/counter.py b/kazoo/recipe/counter.py index 3b2cc339c..531322522 100644 --- a/kazoo/recipe/counter.py +++ b/kazoo/recipe/counter.py @@ -4,9 +4,20 @@ :Status: Unknown """ + +from __future__ import annotations + +import struct +from typing import Union, TYPE_CHECKING + from kazoo.exceptions import BadVersionError from kazoo.retry import ForceRetryError -import struct + +if TYPE_CHECKING: + from kazoo.client import KazooClient + + +Number = Union[int, float] class Counter(object): @@ -58,7 +69,13 @@ class Counter(object): """ - def __init__(self, client, path, default=0, support_curator=False): + def __init__( + self, + client: KazooClient, + path: str, + default: Number = 0, + support_curator: bool = False, + ): """Create a Kazoo Counter :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -74,22 +91,24 @@ def __init__(self, client, path, default=0, support_curator=False): self.default_type = type(default) self.support_curator = support_curator self._ensured_path = False - self.pre_value = None - self.post_value = None + self.pre_value: Number | None = None + self.post_value: Number | None = None if self.support_curator and not isinstance(self.default, int): raise TypeError( "when support_curator is enabled the default " "type must be an int" ) - def _ensure_node(self): + def _ensure_node(self) -> None: if not self._ensured_path: # make sure our node exists self.client.ensure_path(self.path) self._ensured_path = True - def _value(self): + def _value(self) -> tuple[Number, int]: self._ensure_node() + # This is astonishingly hard to follow... + old: Union[bytes, str, Number] old, stat = self.client.get(self.path) if self.support_curator: old = struct.unpack(">i", old)[0] if old != b"" else self.default @@ -100,16 +119,16 @@ def _value(self): return data, version @property - def value(self): + def value(self) -> Number: return self._value()[0] - def _change(self, value): + def _change(self, value: Number) -> Counter: if not isinstance(value, self.default_type): raise TypeError("invalid type for value change") self.client.retry(self._inner_change, value) return self - def _inner_change(self, value): + def _inner_change(self, value: Number) -> None: self.pre_value, version = self._value() post_value = self.pre_value + value if self.support_curator: @@ -123,10 +142,10 @@ def _inner_change(self, value): raise ForceRetryError() self.post_value = post_value - def __add__(self, value): + def __add__(self, value: Number) -> Counter: """Add value to counter.""" return self._change(value) - def __sub__(self, value): + def __sub__(self, value: Number) -> Counter: """Subtract value from counter.""" return self._change(-value) diff --git a/kazoo/recipe/election.py b/kazoo/recipe/election.py index 93bb72580..1e28517b7 100644 --- a/kazoo/recipe/election.py +++ b/kazoo/recipe/election.py @@ -4,8 +4,19 @@ :Status: Unknown """ + +from __future__ import annotations + +from typing import Callable, TYPE_CHECKING +from typing_extensions import ParamSpec + from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + +GenericArgs = ParamSpec("GenericArgs") + class Election(object): """Kazoo Basic Leader Election @@ -22,7 +33,12 @@ class Election(object): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): """Create a Kazoo Leader Election :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -34,7 +50,12 @@ def __init__(self, client, path, identifier=None): """ self.lock = client.Lock(path, identifier) - def run(self, func, *args, **kwargs): + def run( + self, + func: Callable[GenericArgs, None], + *args: GenericArgs.args, + **kwargs: GenericArgs.kwargs, + ) -> None: """Contend for the leadership This call will block until either this contender is cancelled @@ -57,7 +78,7 @@ def run(self, func, *args, **kwargs): except CancelledError: pass - def cancel(self): + def cancel(self) -> None: """Cancel participation in the election .. note:: @@ -68,7 +89,7 @@ def cancel(self): """ self.lock.cancel() - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders in the election diff --git a/kazoo/recipe/lease.py b/kazoo/recipe/lease.py index ce7fe567c..1e4c9cf85 100644 --- a/kazoo/recipe/lease.py +++ b/kazoo/recipe/lease.py @@ -5,12 +5,25 @@ :Status: Beta """ + +from __future__ import annotations + import datetime import json import socket +from typing import Callable, TypedDict, TYPE_CHECKING, cast from kazoo.exceptions import CancelledError +if TYPE_CHECKING: + from kazoo.client import KazooClient + + +class Lease(TypedDict): + version: int + holder: str + end: str + class NonBlockingLease(object): """Exclusive lease that does not block. @@ -48,11 +61,11 @@ class NonBlockingLease(object): def __init__( self, - client, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + path: str, + duration: datetime.timedelta, + identifier: str | None = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): """Create a non-blocking lease. @@ -71,7 +84,14 @@ def __init__( self.obtained = False self._attempt_obtaining(client, path, duration, ident, utcnow) - def _attempt_obtaining(self, client, path, duration, ident, utcnow): + def _attempt_obtaining( + self, + client: KazooClient, + path: str, + duration: datetime.timedelta, + ident: str, + utcnow: Callable[[], datetime.datetime], + ) -> None: client.ensure_path(path) holder_path = path + "/lease_holder" lock = client.Lock(path, ident) @@ -92,7 +112,7 @@ def _attempt_obtaining(self, client, path, duration, ident, utcnow): return client.delete(holder_path) end_lease = (now + duration).strftime(self._date_format) - new_data = { + new_data: Lease = { "version": self._version, "holder": ident, "end": end_lease, @@ -103,18 +123,13 @@ def _attempt_obtaining(self, client, path, duration, ident, utcnow): except CancelledError: pass - def _encode(self, data_dict): + def _encode(self, data_dict: Lease) -> bytes: return json.dumps(data_dict).encode(self._byte_encoding) - def _decode(self, raw): - return json.loads(raw.decode(self._byte_encoding)) - - # Python 2.x - def __nonzero__(self): - return self.obtained + def _decode(self, raw: bytes) -> Lease: + return cast("Lease", json.loads(raw.decode(self._byte_encoding))) - # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained @@ -140,12 +155,12 @@ class MultiNonBlockingLease(object): def __init__( self, - client, - count, - path, - duration, - identifier=None, - utcnow=datetime.datetime.utcnow, + client: KazooClient, + count: int, + path: str, + duration: datetime.timedelta, + identifier: str | None = None, + utcnow: Callable[[], datetime.datetime] = datetime.datetime.utcnow, ): self.obtained = False for num in range(count): @@ -160,10 +175,5 @@ def __init__( self.obtained = True break - # Python 2.x - def __nonzero__(self): - return self.obtained - - # Python 3.x - def __bool__(self): + def __bool__(self) -> bool: return self.obtained diff --git a/kazoo/recipe/lock.py b/kazoo/recipe/lock.py index 1f5247021..6f3236db2 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -14,9 +14,19 @@ and/or the lease has been lost. """ + +from __future__ import annotations + import re import time import uuid +from typing import ( + Iterable, + Literal, + Pattern, + TYPE_CHECKING, +) +from types import TracebackType from kazoo.exceptions import ( CancelledError, @@ -24,27 +34,37 @@ LockTimeout, NoNodeError, ) -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent from kazoo.retry import ( ForceRetryError, KazooRetry, RetryFailedError, ) +if TYPE_CHECKING: + from kazoo.client import KazooClient + class _Watch(object): - def __init__(self, duration=None): + def __init__(self, duration: float | None = None): self.duration = duration - self.started_at = None + self.started_at: float | None = None - def start(self): + def start(self) -> None: self.started_at = time.monotonic() - def leftover(self): + def leftover(self) -> float | None: if self.duration is None: return None else: - elapsed = time.monotonic() - self.started_at + # We should probably set started_at to either 0 or + # time.monotonic() in __init__ to avoid the type ignore + # here, but this is a private class and it's pretty clear + # that start() should be called before leftover() so I'm + # not sure it's worth it. + elapsed = ( + time.monotonic() - self.started_at # type: ignore[operator] + ) return max(0, self.duration - elapsed) @@ -77,7 +97,13 @@ class Lock(object): # sequence number. Involved in read/write locks. _EXCLUDE_NAMES = ["__lock__"] - def __init__(self, client, path, identifier=None, extra_lock_patterns=()): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + extra_lock_patterns: Iterable[str] = (), + ): """Create a Kazoo lock. :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -97,10 +123,10 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): """ self.client = client self.path = path - self._exclude_names = set( + self._exclude_names: set[str] = set( self._EXCLUDE_NAMES + list(extra_lock_patterns) ) - self._contenders_re = re.compile( + self._contenders_re: Pattern[str] = re.compile( r"(?:{patterns})(-?\d{{10}})$".format( patterns="|".join(self._exclude_names) ) @@ -109,7 +135,7 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): # some data is written to the node. this can be queried via # contenders() to see who is contending for the lock self.data = str(identifier or "").encode("utf-8") - self.node = None + self.node: str | None = None self.wake_event = client.handler.event_object() @@ -129,16 +155,21 @@ def __init__(self, client, path, identifier=None, extra_lock_patterns=()): ) self._acquire_method_lock = client.handler.lock_object() - def _ensure_path(self): + def _ensure_path(self) -> None: self.client.ensure_path(self.path) self.assured_path = True - def cancel(self): + def cancel(self) -> None: """Cancel a pending lock acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None, ephemeral=True): + def acquire( + self, + blocking: bool = True, + timeout: float | None = None, + ephemeral: bool = True, + ) -> bool: """ Acquire the lock. By defaults blocks and waits forever. @@ -204,11 +235,16 @@ def acquire(self, blocking=True, timeout=None, ephemeral=True): finally: self._acquire_method_lock.release() - def _watch_session(self, state): + def _watch_session(self, state: KazooState) -> bool: self.wake_event.set() return True - def _inner_acquire(self, blocking, timeout, ephemeral=True): + def _inner_acquire( + self, + blocking: bool, + timeout: float | None, + ephemeral: bool = True, + ) -> bool: # wait until it's our chance to get it.. if self.is_acquired: if not blocking: @@ -219,7 +255,7 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): if not self.assured_path: self._ensure_path() - node = None + node: str | None = None if self.create_tried: node = self._find_node() else: @@ -265,10 +301,10 @@ def _inner_acquire(self, blocking, timeout, ephemeral=True): finally: self.client.remove_listener(self._watch_session) - def _watch_predecessor(self, event): + def _watch_predecessor(self, event: WatchedEvent) -> None: self.wake_event.set() - def _get_predecessor(self, node): + def _get_predecessor(self, node: str) -> str | None: """returns `node`'s predecessor or None Note: This handle the case where the current lock is not a contender @@ -277,7 +313,7 @@ def _get_predecessor(self, node): """ node_sequence = node[len(self.prefix) :] children = self.client.get_children(self.path) - found_self = False + found_self: Literal[False] | re.Match[str] | None = False # Filter out the contenders using the computed regex contender_matches = [] for child in children: @@ -308,17 +344,17 @@ def _get_predecessor(self, node): sorted_matches = sorted(contender_matches, key=lambda m: m.groups()) return sorted_matches[-1].string - def _find_node(self): + def _find_node(self) -> str | None: children = self.client.get_children(self.path) for child in children: if child.startswith(self.prefix): return child return None - def _delete_node(self, node): + def _delete_node(self, node: str) -> None: self.client.delete(self.path + "/" + node) - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: node = self.node or self._find_node() if node: @@ -326,16 +362,18 @@ def _best_effort_cleanup(self): except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lock immediately.""" return self.client.retry(self._inner_release) - def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: - self._delete_node(self.node) + # I don't think it's possible for self.node to be None here if + # self.is_acquired is true. + self._delete_node(self.node) # type: ignore[arg-type] except NoNodeError: # pragma: nocover pass @@ -343,7 +381,7 @@ def _inner_release(self): self.node = None return True - def contenders(self): + def contenders(self) -> list[str]: """Return an ordered list of the current contenders for the lock. @@ -390,11 +428,17 @@ def contenders(self): return contenders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> bool | None: self.release() + return None class WriteLock(Lock): @@ -492,7 +536,13 @@ class Semaphore(object): """ - def __init__(self, client, path, identifier=None, max_leases=1): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + max_leases: int = 1, + ): """Create a Kazoo Lock :param client: A :class:`~kazoo.client.KazooClient` instance. @@ -528,7 +578,7 @@ def __init__(self, client, path, identifier=None, max_leases=1): self.cancelled = False self._session_expired = False - def _ensure_path(self): + def _ensure_path(self) -> None: result = self.client.ensure_path(self.path) self.assured_path = True if result is True: @@ -549,12 +599,16 @@ def _ensure_path(self): else: self.client.set(self.path, str(self.max_leases).encode("utf-8")) - def cancel(self): + def cancel(self) -> None: """Cancel a pending semaphore acquire.""" self.cancelled = True self.wake_event.set() - def acquire(self, blocking=True, timeout=None): + def acquire( + self, + blocking: bool = True, + timeout: float | None = None, + ) -> bool: """Acquire the semaphore. By defaults blocks and waits forever. :param blocking: Block until semaphore is obtained or @@ -592,7 +646,11 @@ def acquire(self, blocking=True, timeout=None): return self.is_acquired - def _inner_acquire(self, blocking, timeout=None): + def _inner_acquire( + self, + blocking: bool, + timeout: float | None = None, + ) -> bool: """Inner loop that runs from the top anytime a command hits a retryable Zookeeper exception.""" self._session_expired = False @@ -607,7 +665,12 @@ def _inner_acquire(self, blocking, timeout=None): w = _Watch(duration=timeout) w.start() - lock = self.client.Lock(self.lock_path, self.data) + # FIXME This is passing bytes data, but self.client.Lock expects a str, + # which I think is a bug in this code. However, I don't want to + # change any code at this point, so we just ignore the type error here. + lock = self.client.Lock( + self.lock_path, self.data # type: ignore[arg-type] + ) try: gotten = lock.acquire(blocking=blocking, timeout=w.leftover()) if not gotten: @@ -633,10 +696,10 @@ def _inner_acquire(self, blocking, timeout=None): finally: lock.release() - def _watch_lease_change(self, event): + def _watch_lease_change(self, event: WatchedEvent) -> None: self.wake_event.set() - def _get_lease(self, data=None): + def _get_lease(self) -> bool: # Make sure the session is still valid if self._session_expired: raise ForceRetryError("Retry on session loss at top") @@ -665,25 +728,26 @@ def _get_lease(self, data=None): # Return current state return self.is_acquired - def _watch_session(self, state): + def _watch_session(self, state: KazooState) -> bool | None: if state == KazooState.LOST: self._session_expired = True self.wake_event.set() # Return true to de-register return True + return None - def _best_effort_cleanup(self): + def _best_effort_cleanup(self) -> None: try: self.client.delete(self.create_path) except KazooException: # pragma: nocover pass - def release(self): + def release(self) -> bool: """Release the lease immediately.""" return self.client.retry(self._inner_release) - def _inner_release(self): + def _inner_release(self) -> bool: if not self.is_acquired: return False try: @@ -693,7 +757,7 @@ def _inner_release(self): self.is_acquired = False return True - def lease_holders(self): + def lease_holders(self) -> list[str]: """Return an unordered list of the current lease holders. .. note:: @@ -716,8 +780,13 @@ def lease_holders(self): pass return lease_holders - def __enter__(self): + def __enter__(self) -> None: self.acquire() - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.release() diff --git a/kazoo/recipe/partitioner.py b/kazoo/recipe/partitioner.py index 21dc6ef4a..bcc3b1917 100644 --- a/kazoo/recipe/partitioner.py +++ b/kazoo/recipe/partitioner.py @@ -17,20 +17,38 @@ so that no two workers own the same queue. """ + +from __future__ import annotations + from functools import partial import logging import os import socket +from enum import Enum +from typing import ( + Callable, + Generic, + Iterator, + Iterable, + TYPE_CHECKING, + TypeVar, +) from kazoo.exceptions import KazooException, LockTimeout from kazoo.protocol.states import KazooState from kazoo.recipe.watchers import PatientChildrenWatch +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import Event, IAsyncResult + from kazoo.recipe.lock import Lock + from _typeshed import SupportsRichComparisonT log = logging.getLogger(__name__) -class PartitionState(object): +# This is a (str, Enum) for backwards compatibility. +class PartitionState(str, Enum): """High level partition state values .. attribute:: ALLOCATING @@ -63,7 +81,20 @@ class PartitionState(object): FAILURE = "FAILURE" -class SetPartitioner(object): +if TYPE_CHECKING: + # FIXME There should be a better way of doing this, but it's not clear how + # we can achieve what we want (specify that PartitionDataT is a type which + # is a SupportsRichComparisonT type) but clearly you can only do that with + # typechecking on. + PartitionDataT = TypeVar( + "PartitionDataT", + bound=SupportsRichComparisonT, # type:ignore[valid-type] + ) +else: + PartitionDataT = TypeVar("PartitionDataT") + + +class SetPartitioner(Generic[PartitionDataT]): """Partitions a set amongst members of a party This class will partition a set amongst members of a party such @@ -139,14 +170,18 @@ class SetPartitioner(object): def __init__( self, - client, - path, - set, - partition_func=None, - identifier=None, - time_boundary=30, - max_reaction_time=1, - state_change_event=None, + client: KazooClient, + path: str, + set: Iterable[PartitionDataT], + partition_func: Callable[ + [str, Iterable[str], Iterable[PartitionDataT]], + list[PartitionDataT], + ] + | None = None, + identifier: str | None = None, + time_boundary: float = 30, + max_reaction_time: float = 1, + state_change_event: Event | None = None, ): """Create a :class:`~SetPartitioner` instance @@ -176,13 +211,13 @@ def __init__( self._client = client self._path = path self._set = set - self._partition_set = [] + self._partition_set: list[PartitionDataT] = [] self._partition_func = partition_func or self._partitioner self._identifier = identifier or "%s-%s" % ( socket.getfqdn(), os.getpid(), ) - self._locks = [] + self._locks: list[Lock] = [] self._lock_path = "/".join([path, "locks"]) self._party_path = "/".join([path, "party"]) self._time_boundary = time_boundary @@ -208,33 +243,33 @@ def __init__( # so we know when we're ready self._child_watching(self._allocate_transition, client_handler=True) - def __iter__(self): + def __iter__(self) -> Iterator[PartitionDataT]: """Return the partitions in this partition set""" for partition in self._partition_set: yield partition @property - def failed(self): + def failed(self) -> bool: """Corresponds to the :attr:`PartitionState.FAILURE` state""" return self.state == PartitionState.FAILURE @property - def release(self): + def release(self) -> bool: """Corresponds to the :attr:`PartitionState.RELEASE` state""" return self.state == PartitionState.RELEASE @property - def allocating(self): + def allocating(self) -> bool: """Corresponds to the :attr:`PartitionState.ALLOCATING` state""" return self.state == PartitionState.ALLOCATING @property - def acquired(self): + def acquired(self) -> bool: """Corresponds to the :attr:`PartitionState.ACQUIRED` state""" return self.state == PartitionState.ACQUIRED - def wait_for_acquire(self, timeout=30): + def wait_for_acquire(self, timeout: float = 30) -> None: """Wait for the set to be partitioned and acquired :param timeout: How long to wait before returning. @@ -243,7 +278,7 @@ def wait_for_acquire(self, timeout=30): """ self._acquire_event.wait(timeout) - def release_set(self): + def release_set(self) -> None: """Call to release the set This method begins the step of allocating once the set has @@ -263,12 +298,12 @@ def release_set(self): self._set_state(PartitionState.ALLOCATING) self._child_watching(self._allocate_transition, client_handler=True) - def finish(self): + def finish(self) -> None: """Call to release the set and leave the party""" self._release_locks() self._fail_out() - def _fail_out(self): + def _fail_out(self) -> None: with self._state_change: self._set_state(PartitionState.FAILURE) if self._party.participating: @@ -277,7 +312,7 @@ def _fail_out(self): except KazooException: # pragma: nocover pass - def _allocate_transition(self, result): + def _allocate_transition(self, result: IAsyncResult) -> None: """Called when in allocating mode, and the children settled""" # Did we get an exception waiting for children to settle? @@ -288,7 +323,7 @@ def _allocate_transition(self, result): children, async_result = result.get() children_changed = self._client.handler.event_object() - def updated(result): + def updated(result: IAsyncResult) -> None: with self._state_change: children_changed.set() if self.acquired: @@ -307,7 +342,7 @@ def updated(result): # Check whether the state has changed during the lock acquisition # and abort the process if so. - def abort_if_needed(): + def abort_if_needed() -> bool: if self.state_id == state_id: if children_changed.is_set(): # The party has changed. Repartitioning... @@ -365,7 +400,7 @@ def abort_if_needed(): # This mustn't happen. Means a logical error. self._fail_out() - def _release_locks(self): + def _release_locks(self) -> None: """Attempt to completely remove all the locks""" self._acquire_event.clear() for lock in self._locks[:]: @@ -378,7 +413,7 @@ def _release_locks(self): else: self._locks.remove(lock) - def _abort_lock_acquisition(self): + def _abort_lock_acquisition(self) -> None: """Called during lock acquisition if a party change occurs""" self._release_locks() @@ -391,7 +426,13 @@ def _abort_lock_acquisition(self): self._child_watching(self._allocate_transition, client_handler=True) - def _child_watching(self, func=None, client_handler=False): + # FIXME This is only ever called with func=self._allocation_transition, but + # I didn't want to change the code. + def _child_watching( + self, + func: Callable[[IAsyncResult], None] | None = None, + client_handler: bool = False, + ) -> IAsyncResult: """Called when children are being watched to stabilize This actually returns immediately, child watcher spins up a @@ -410,11 +451,15 @@ def _child_watching(self, func=None, client_handler=False): # to ensure that the rawlink's it might use won't be # blocked if client_handler: - func = partial(self._client.handler.spawn, func) + # FIXME This feels wrong, but it may be because partial is + # confusing things. + func = partial( # type: ignore[assignment] + self._client.handler.spawn, func + ) asy.rawlink(func) return asy - def _establish_sessionwatch(self, state): + def _establish_sessionwatch(self, state: KazooState) -> bool: """Register ourself to listen for session events, we shut down if we become lost""" with self._state_change: @@ -427,7 +472,12 @@ def _establish_sessionwatch(self, state): return state == KazooState.LOST - def _partitioner(self, identifier, members, partitions): + def _partitioner( + self, + identifier: str, + members: Iterable[str], + partitions: Iterable[PartitionDataT], + ) -> list[PartitionDataT]: # Ensure consistent order of partitions/members all_partitions = sorted(partitions) workers = sorted(members) @@ -437,7 +487,7 @@ def _partitioner(self, identifier, members, partitions): # skipping the other workers return all_partitions[i :: len(workers)] - def _set_state(self, state): + def _set_state(self, state: PartitionState) -> None: self.state = state self.state_id += 1 self.state_change_event.set() diff --git a/kazoo/recipe/party.py b/kazoo/recipe/party.py index 2a0f5dfb6..1fc1340b7 100644 --- a/kazoo/recipe/party.py +++ b/kazoo/recipe/party.py @@ -7,15 +7,27 @@ used for determining members of a party. """ + +from __future__ import annotations + import uuid +from typing import Iterator, TYPE_CHECKING from kazoo.exceptions import NodeExistsError, NoNodeError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseParty(object): """Base implementation of a party.""" - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The party path to use. @@ -29,44 +41,52 @@ def __init__(self, client, path, identifier=None): self.ensured_path = False self.participating = False - def _ensure_parent(self): + def _ensure_parent(self) -> None: if not self.ensured_path: # make sure our parent node exists self.client.ensure_path(self.path) self.ensured_path = True - def join(self): + def join(self) -> None: """Join the party""" return self.client.retry(self._inner_join) - def _inner_join(self): + def _inner_join(self) -> None: self._ensure_parent() try: - self.client.create(self.create_path, self.data, ephemeral=True) + # This and the #type: ignore[attr-defined] below could be removed + # by setting up create_path in the constructor but trying to avoid + # changing the code too much. It does actually cause later versions + # of pylint to error though. + self.client.create( + self.create_path, # type: ignore[attr-defined] + self.data, + ephemeral=True, + ) self.participating = True except NodeExistsError: # node was already created, perhaps we are recovering from a # suspended connection self.participating = True - def leave(self): + def leave(self) -> bool: """Leave the party""" self.participating = False return self.client.retry(self._inner_leave) - def _inner_leave(self): + def _inner_leave(self) -> bool: try: - self.client.delete(self.create_path) + self.client.delete(self.create_path) # type: ignore[attr-defined] except NoNodeError: return False return True - def __len__(self): + def __len__(self) -> int: """Return a count of participating clients""" self._ensure_parent() return len(self._get_children()) - def _get_children(self): + def _get_children(self) -> list[str]: return self.client.retry(self.client.get_children, self.path) @@ -75,12 +95,14 @@ class Party(BaseParty): _NODE_NAME = "__party__" - def __init__(self, client, path, identifier=None): + def __init__( + self, client: KazooClient, path: str, identifier: str | None = None + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = uuid.uuid4().hex + self._NODE_NAME self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' data values""" self._ensure_parent() children = self._get_children() @@ -93,7 +115,7 @@ def __iter__(self): except NoNodeError: # pragma: nocover pass - def _get_children(self): + def _get_children(self) -> list[str]: children = BaseParty._get_children(self) return [c for c in children if self._NODE_NAME in c] @@ -109,12 +131,17 @@ class ShallowParty(BaseParty): """ - def __init__(self, client, path, identifier=None): + def __init__( + self, + client: KazooClient, + path: str, + identifier: str | None = None, + ): BaseParty.__init__(self, client, path, identifier=identifier) self.node = "-".join([uuid.uuid4().hex, self.data.decode("utf-8")]) self.create_path = self.path + "/" + self.node - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Get a list of participating clients' identifiers""" self._ensure_parent() children = self._get_children() diff --git a/kazoo/recipe/queue.py b/kazoo/recipe/queue.py index 30d3066e4..85a866764 100644 --- a/kazoo/recipe/queue.py +++ b/kazoo/recipe/queue.py @@ -9,17 +9,24 @@ See: https://github.com/python-zk/kazoo/issues/175 """ + +from __future__ import annotations + import uuid +from typing import TYPE_CHECKING from kazoo.exceptions import NoNodeError, NodeExistsError -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent from kazoo.retry import ForceRetryError +if TYPE_CHECKING: + from kazoo.client import KazooClient + class BaseQueue(object): """A common base class for queue implementations.""" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. @@ -27,10 +34,10 @@ def __init__(self, client, path): self.client = client self.path = path self._entries_path = path - self.structure_paths = (self.path,) + self.structure_paths: tuple[str, ...] = (self.path,) self.ensured_path = False - def _check_put_arguments(self, value, priority=100): + def _check_put_arguments(self, value: bytes, priority: int = 100) -> None: if not isinstance(value, bytes): raise TypeError("value must be a byte string") if not isinstance(priority, int): @@ -38,14 +45,14 @@ def _check_put_arguments(self, value, priority=100): elif priority < 0 or priority > 999: raise ValueError("priority must be between 0 and 999") - def _ensure_paths(self): + def _ensure_paths(self) -> None: if not self.ensured_path: # make sure our parent / internal structure nodes exists for path in self.structure_paths: self.client.ensure_path(path) self.ensured_path = True - def __len__(self): + def __len__(self) -> int: self._ensure_paths() _, stat = self.client.retry(self.client.get, self._entries_path) return stat.children_count @@ -62,19 +69,19 @@ class Queue(BaseQueue): prefix = "entry-" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(Queue, self).__init__(client, path) - self._children = [] + self._children: list[str] = [] - def __len__(self): + def __len__(self) -> int: """Return queue size.""" return super(Queue, self).__len__() - def get(self): + def get(self) -> bytes | None: """ Get item data and remove an item from the queue. @@ -84,7 +91,7 @@ def get(self): self._ensure_paths() return self.client.retry(self._inner_get) - def _inner_get(self): + def _inner_get(self) -> bytes | None: if not self._children: self._children = self.client.retry( self.client.get_children, self.path @@ -105,7 +112,7 @@ def _inner_get(self): self._children.pop(0) return data - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an item into the queue. :param value: Byte string to put into the queue. @@ -150,26 +157,26 @@ class LockingQueue(BaseQueue): entries = "/entries" entry = "entry" - def __init__(self, client, path): + def __init__(self, client: KazooClient, path: str): """ :param client: A :class:`~kazoo.client.KazooClient` instance. :param path: The queue path to use in ZooKeeper. """ super(LockingQueue, self).__init__(client, path) self.id = uuid.uuid4().hex.encode() - self.processing_element = None + self.processing_element: tuple[str, bytes] | None = None self._lock_path = self.path + self.lock self._entries_path = self.path + self.entries self.structure_paths = (self._lock_path, self._entries_path) - def __len__(self): + def __len__(self) -> int: """Returns the current length of the queue. :returns: queue size (includes locked entries count). """ return super(LockingQueue, self).__len__() - def put(self, value, priority=100): + def put(self, value: bytes, priority: int = 100) -> None: """Put an entry into the queue. :param value: Byte string to put into the queue. @@ -189,7 +196,7 @@ def put(self, value, priority=100): sequence=True, ) - def put_all(self, values, priority=100): + def put_all(self, values: list[bytes], priority: int = 100) -> None: """Put several entries into the queue. The action only succeeds if all entries where put into the queue. @@ -221,7 +228,7 @@ def put_all(self, values, priority=100): sequence=True, ) - def get(self, timeout=None): + def get(self, timeout: float | None = None) -> bytes | None: """Locks and gets an entry from the queue. If a previously got entry was not consumed, this method will return that entry. @@ -237,7 +244,7 @@ def get(self, timeout=None): else: return self._inner_get(timeout) - def holds_lock(self): + def holds_lock(self) -> bool: """Checks if a node still holds the lock. :returns: True if a node still holds the lock, False otherwise. @@ -251,7 +258,7 @@ def holds_lock(self): value, stat = self.client.retry(self.client.get, lock_path) return value == self.id - def consume(self): + def consume(self) -> bool: """Removes a currently processing entry from the queue. :returns: True if element was removed successfully, False otherwise. @@ -271,7 +278,7 @@ def consume(self): else: return False - def release(self): + def release(self) -> bool: """Removes the lock from currently processed item without consuming it. :returns: True if the lock was removed successfully, False otherwise. @@ -289,13 +296,13 @@ def release(self): else: return False - def _inner_get(self, timeout): + def _inner_get(self, timeout: float | None) -> bytes | None: flag = self.client.handler.event_object() lock = self.client.handler.lock_object() canceled = False value = [] - def check_for_updates(event): + def check_for_updates(event: WatchedEvent | None) -> None: if event is not None and event.type != EventType.CHILD: return with lock: @@ -330,8 +337,8 @@ def check_for_updates(event): retVal = value[0][1] return retVal - def _filter_locked(self, values, taken): - taken = set(taken) + def _filter_locked(self, values: list[str], taken: list[str]) -> list[str]: + taken = set(taken) # type: ignore[assignment] available = sorted(values) return ( available @@ -339,7 +346,7 @@ def _filter_locked(self, values, taken): else [x for x in available if x not in taken] ) - def _take(self, id_): + def _take(self, id_: str) -> tuple[str, bytes] | None: try: self.client.create( "{path}/{id}".format(path=self._lock_path, id=id_), diff --git a/kazoo/recipe/watchers.py b/kazoo/recipe/watchers.py index d4cb0300e..77f54547f 100644 --- a/kazoo/recipe/watchers.py +++ b/kazoo/recipe/watchers.py @@ -10,25 +10,45 @@ will result in an exception being thrown. """ + +from __future__ import annotations + from functools import partial, wraps import logging import time import warnings +from typing import ( + Any, + List, + Callable, + Optional, + Union, + TYPE_CHECKING, + overload, +) +from typing_extensions import ParamSpec, deprecated from kazoo.exceptions import ConnectionClosedError, NoNodeError, KazooException -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, WatchedEvent, ZnodeStat from kazoo.retry import KazooRetry +if TYPE_CHECKING: + from kazoo.client import KazooClient + from kazoo.interfaces import IAsyncResult log = logging.getLogger(__name__) _STOP_WATCHING = object() +GenericArgs = ParamSpec("GenericArgs") + -def _ignore_closed(func): +def _ignore_closed( + func: Callable[GenericArgs, None] +) -> Callable[GenericArgs, None]: @wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: GenericArgs.args, **kwargs: GenericArgs.kwargs) -> None: try: return func(*args, **kwargs) except ConnectionClosedError: @@ -37,6 +57,15 @@ def wrapper(*args, **kwargs): return wrapper +DataWatchFunc = Union[ + Callable[[Optional[bytes], Optional[ZnodeStat]], Optional[bool]], + Callable[ + [Optional[bytes], Optional[ZnodeStat], Optional[WatchedEvent]], + Optional[bool], + ], +] + + class DataWatch(object): """Watches a node for data updates and calls the specified function each time it changes @@ -88,7 +117,39 @@ def my_func(data, stat, event): """ - def __init__(self, client, path, func=None, *args, **kwargs): + @overload + def __init__( + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + ): + ... + + @overload + @deprecated( + "Passing additional arguments to DataWatch is deprecated. " + "ignore_missing_node is now assumed to be True by default, and the " + "event will be sent if the function can handle receiving it" + ) + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + *args: Any, + **kwargs: Any, + ): + ... + + def __init__( # type: ignore[misc] + self, + client: KazooClient, + path: str, + func: DataWatchFunc | None = None, + *args: Any, + **kwargs: Any, + ): """Create a data watcher for a path :param client: A zookeeper client. @@ -107,7 +168,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._func = func self._stopped = False self._run_lock = client.handler.lock_object() - self._version = None + self._version: int | None = None self._retry = KazooRetry( max_tries=None, sleep_func=client.handler.sleep_func ) @@ -132,7 +193,7 @@ def __init__(self, client, path, func=None, *args, **kwargs): self._client.add_listener(self._session_watcher) self._get_data() - def __call__(self, func): + def __call__(self, func: DataWatchFunc) -> DataWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -155,16 +216,28 @@ def __call__(self, func): self._get_data() return func - def _log_func_exception(self, data, stat, event=None): + def _log_func_exception( + self, + data: bytes | None, + stat: ZnodeStat | None, + event: WatchedEvent | None = None, + ) -> None: try: # For backwards compatibility, don't send event to the # callback unless the send_event is set in constructor if not self._ever_called: self._ever_called = True try: - result = self._func(data, stat, event) + # The type ignores here are because mypy can't figure out that + # 1) self._func can't ever be None (fingers crossed) + # 2) the function can be called with 2 arguments or with 3 + # arguments (though that could possibly be done with better + # typing) + result = self._func( # type: ignore[call-arg, misc] + data, stat, event + ) except TypeError: - result = self._func(data, stat) + result = self._func(data, stat) # type: ignore[call-arg, misc] if result is False: self._stopped = True self._func = None @@ -174,7 +247,7 @@ def _log_func_exception(self, data, stat, event=None): raise @_ignore_closed - def _get_data(self, event=None): + def _get_data(self, event: WatchedEvent | None = None) -> None: # Ensure this runs one at a time, possible because the session # watcher may trigger a run with self._run_lock: @@ -183,6 +256,7 @@ def _get_data(self, event=None): initial_version = self._version + stat: ZnodeStat | None try: data, stat = self._retry( self._client.get, self._path, self._watcher @@ -210,18 +284,24 @@ def _get_data(self, event=None): if initial_version != self._version or not self._ever_called: self._log_func_exception(data, stat, event) - def _watcher(self, event): + def _watcher(self, event: WatchedEvent) -> None: self._get_data(event=event) - def _set_watch(self, state): + def _set_watch(self, state: KazooState) -> None: with self._run_lock: self._watch_established = state - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state == KazooState.CONNECTED: self._client.handler.spawn(self._get_data) +ChildrenWatchFunc = Union[ + Callable[[List[str]], Optional[bool]], + Callable[[List[str], Optional[WatchedEvent]], Optional[bool]], +] + + class ChildrenWatch(object): """Watches a node for children updates and calls the specified function each time it changes @@ -253,11 +333,11 @@ def my_func(children): def __init__( self, - client, - path, - func=None, - allow_session_lost=True, - send_event=False, + client: KazooClient, + path: str, + func: ChildrenWatchFunc | None = None, + allow_session_lost: bool = True, + send_event: bool = False, ): """Create a children watcher for a path @@ -290,7 +370,7 @@ def __init__( self._watch_established = False self._allow_session_lost = allow_session_lost self._run_lock = client.handler.lock_object() - self._prior_children = None + self._prior_children: list[str] | None = None self._used = False # Register our session listener if we're going to resume @@ -301,7 +381,7 @@ def __init__( self._client.add_listener(self._session_watcher) self._get_children() - def __call__(self, func): + def __call__(self, func: ChildrenWatchFunc) -> ChildrenWatchFunc: """Callable version for use as a decorator :param func: Function to call initially and every time the @@ -325,7 +405,7 @@ def __call__(self, func): return func @_ignore_closed - def _get_children(self, event=None): + def _get_children(self, event: WatchedEvent | None = None) -> None: with self._run_lock: # Ensure this runs one at a time if self._stopped: return @@ -351,9 +431,16 @@ def _get_children(self, event=None): try: if self._send_event: - result = self._func(children, event) + # See comment about the type ignore here in DataWatch, + # it's the same issue where mypy can't figure out that the + # function can be called with 1 argument or with 2 + result = self._func( # type: ignore[misc] + children, event # type: ignore[call-arg] + ) else: - result = self._func(children) + result = self._func( # type: ignore[misc] + children # type: ignore[call-arg] + ) if result is False: self._stopped = True self._func = None @@ -363,11 +450,11 @@ def _get_children(self, event=None): log.exception(exc) raise - def _watcher(self, event): + def _watcher(self, event: WatchedEvent) -> None: if event.type != "NONE": self._get_children(event) - def _session_watcher(self, state): + def _session_watcher(self, state: KazooState) -> None: if state in (KazooState.LOST, KazooState.SUSPENDED): self._watch_established = False elif ( @@ -408,14 +495,16 @@ class PatientChildrenWatch(object): """ - def __init__(self, client, path, time_boundary=30): + def __init__( + self, client: KazooClient, path: str, time_boundary: float = 30 + ): self.client = client self.path = path - self.children = [] + self.children: list[str] = [] self.time_boundary = time_boundary self.children_changed = client.handler.event_object() - def start(self): + def start(self) -> IAsyncResult: """Begin the watching process asynchronously :returns: An :class:`~kazoo.interfaces.IAsyncResult` instance @@ -427,7 +516,7 @@ def start(self): self.client.handler.spawn(self._inner_start) return asy - def _inner_start(self): + def _inner_start(self) -> None: try: while True: async_result = self.client.handler.async_result() @@ -447,6 +536,8 @@ def _inner_start(self): except Exception as exc: self.asy.set_exception(exc) - def _children_watcher(self, async_result, event): + def _children_watcher( + self, async_result: IAsyncResult, event: WatchedEvent + ) -> None: self.children_changed.set() async_result.set(time.monotonic()) diff --git a/kazoo/retry.py b/kazoo/retry.py index fb9e8fc7b..9e4e0c9a8 100644 --- a/kazoo/retry.py +++ b/kazoo/retry.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import logging import random import time +from typing import Any, Callable, TypeVar from kazoo.exceptions import ( ConnectionClosedError, @@ -10,7 +13,6 @@ SessionExpiredError, ) - log = logging.getLogger(__name__) @@ -37,18 +39,20 @@ class KazooRetry(object): EXPIRED_EXCEPTIONS = (SessionExpiredError,) + RETRY_RETURN = TypeVar("RETRY_RETURN") + def __init__( self, - max_tries=1, - delay=0.1, - backoff=2, - max_jitter=0.4, - max_delay=60.0, - ignore_expire=True, - sleep_func=time.sleep, - deadline=None, - interrupt=None, - ): + max_tries: int | None = 1, + delay: float = 0.1, + backoff: int = 2, + max_jitter: float = 0.4, + max_delay: float = 60.0, + ignore_expire: bool = True, + sleep_func: Callable[[float], None] = time.sleep, + deadline: float | None = None, + interrupt: Callable[[], bool] | None = None, + ) -> None: """Create a :class:`KazooRetry` instance for retrying function calls. @@ -81,20 +85,22 @@ def __init__( self._attempts = 0 self._cur_delay = delay self.deadline = deadline - self._cur_stoptime = None + self._cur_stoptime: float | None = None self.sleep_func = sleep_func - self.retry_exceptions = self.RETRY_EXCEPTIONS + self.retry_exceptions: tuple[ + type[Exception], ... + ] = self.RETRY_EXCEPTIONS self.interrupt = interrupt if ignore_expire: self.retry_exceptions += self.EXPIRED_EXCEPTIONS - def reset(self): + def reset(self) -> None: """Reset the attempt counter""" self._attempts = 0 self._cur_delay = self.delay self._cur_stoptime = None - def copy(self): + def copy(self) -> KazooRetry: """Return a clone of this retry manager""" obj = KazooRetry( max_tries=self.max_tries, @@ -109,7 +115,12 @@ def copy(self): obj.retry_exceptions = self.retry_exceptions return obj - def __call__(self, func, *args, **kwargs): + def __call__( + self, + func: Callable[..., RETRY_RETURN], + *args: Any, + **kwargs: Any, + ) -> RETRY_RETURN: """Call a function with arguments until it completes without throwing a Kazoo exception diff --git a/kazoo/security.py b/kazoo/security.py index 683994451..1b383795f 100644 --- a/kazoo/security.py +++ b/kazoo/security.py @@ -1,14 +1,19 @@ """Kazoo Security""" + +from __future__ import annotations + from base64 import b64encode -from collections import namedtuple import hashlib +from typing import NamedTuple # Represents a Zookeeper ID and ACL object -Id = namedtuple("Id", "scheme id") +class Id(NamedTuple): + scheme: str + id: str -class ACL(namedtuple("ACL", "perms id")): +class ACL(NamedTuple): """An ACL for a Zookeeper Node An ACL object is created by using an :class:`Id` object along with @@ -17,8 +22,11 @@ class ACL(namedtuple("ACL", "perms id")): the desired scheme, id, and permissions. """ + perms: int + id: Id + @property - def acl_list(self): + def acl_list(self) -> list[str]: perms = [] if self.perms & Permissions.ALL == Permissions.ALL: perms.append("ALL") @@ -35,7 +43,7 @@ def acl_list(self): perms.append("ADMIN") return perms - def __repr__(self): + def __repr__(self) -> str: return "ACL(perms=%r, acl_list=%s, id=%r)" % ( self.perms, self.acl_list, @@ -62,7 +70,7 @@ class Permissions(object): READ_ACL_UNSAFE = [ACL(Permissions.READ, ANYONE_ID_UNSAFE)] -def make_digest_acl_credential(username, password): +def make_digest_acl_credential(username: str, password: str) -> str: """Create a SHA1 digest credential. .. note:: @@ -80,15 +88,15 @@ def make_digest_acl_credential(username, password): def make_acl( - scheme, - credential, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + scheme: str, + credential: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Given a scheme and credential, return an :class:`ACL` object appropriate for use with Kazoo. @@ -131,15 +139,15 @@ def make_acl( def make_digest_acl( - username, - password, - read=False, - write=False, - create=False, - delete=False, - admin=False, - all=False, -): + username: str, + password: str, + read: bool = False, + write: bool = False, + create: bool = False, + delete: bool = False, + admin: bool = False, + all: bool = False, +) -> ACL: """Create a digest ACL for Zookeeper with the given permissions This method combines :meth:`make_digest_acl_credential` and diff --git a/kazoo/testing/common.py b/kazoo/testing/common.py index 4f702e42b..6f97c547e 100644 --- a/kazoo/testing/common.py +++ b/kazoo/testing/common.py @@ -18,7 +18,7 @@ # # You should have received a copy of the GNU Lesser General Public License # along with txzookeeper. If not, see . - +from __future__ import annotations import code from collections import namedtuple @@ -37,16 +37,19 @@ import OpenSSL import jks +from types import FrameType +from typing import Any, Iterator log = logging.getLogger(__name__) -def debug(sig, frame): +def debug(sig: int, frame: FrameType | None) -> None: """Interrupt running process, and provide a python prompt for interactive debugging.""" d = {"_frame": frame} # Allow access to frame object. - d.update(frame.f_globals) # Unless shadowed by global - d.update(frame.f_locals) + if frame is not None: + d.update(frame.f_globals) # Unless shadowed by global + d.update(frame.f_locals) i = code.InteractiveConsole(d) message = "Signal received : entering python shell.\nTraceback:\n" @@ -54,7 +57,7 @@ def debug(sig, frame): i.interact(message) -def listen(): +def listen() -> None: if os.name != "nt": # SIGUSR1 is not supported on Windows signal.signal(signal.SIGUSR1, debug) # Register handler @@ -62,7 +65,7 @@ def listen(): listen() -def to_java_compatible_path(path): +def to_java_compatible_path(path: str) -> str: if os.name == "nt": path = path.replace("\\", "/") return path @@ -85,14 +88,14 @@ class ManagedZooKeeper(object): def __init__( self, - software_path, - server_info, - peers=(), - classpath=None, - configuration_entries=(), - java_system_properties=(), - jaas_config=None, - ssl_configuration=None, + software_path: str, + server_info: ServerInfo, + peers: list[ServerInfo], + classpath: str, + configuration_entries: list[str], + java_system_properties: list[str], + jaas_config: str | None = None, + ssl_configuration: dict[str, Any] | None = None, ): """Define the ZooKeeper test instance. @@ -104,8 +107,8 @@ def __init__( self.server_info = server_info self.host = "127.0.0.1" self.peers = peers - self.working_path = tempfile.mkdtemp() - self._running = False + self.working_path: str = tempfile.mkdtemp() + self._running: bool = False self.configuration_entries = configuration_entries self.java_system_properties = java_system_properties self.jaas_config = jaas_config @@ -113,7 +116,7 @@ def __init__( ssl_configuration if ssl_configuration is not None else {} ) - def run(self): + def run(self) -> None: """Run the ZooKeeper instance under a temporary directory. Writes ZK log messages to zookeeper.log in the current directory. @@ -257,7 +260,7 @@ def run(self): self._running = True @property - def classpath(self): + def classpath(self) -> str: """Get the classpath necessary to run ZooKeeper.""" if self._classpath: @@ -290,28 +293,28 @@ def classpath(self): return os.pathsep.join(jars) @property - def address(self): + def address(self) -> str: """Get the address of the ZooKeeper instance.""" return "%s:%s" % (self.host, self.client_port) @property - def secure_address(self): + def secure_address(self) -> str: """Get the address of the SSL ZooKeeper instance.""" return "%s:%s" % (self.host, self.secure_client_port) @property - def running(self): + def running(self) -> bool: return self._running @property - def client_port(self): + def client_port(self) -> Any: return self.server_info.client_port @property - def secure_client_port(self): + def secure_client_port(self) -> Any: return self.server_info.secure_client_port - def reset(self): + def reset(self) -> None: """Stop the zookeeper instance, cleaning out its on disk-data.""" self.stop() shutil.rmtree(os.path.join(self.working_path, "data"), True) @@ -319,7 +322,7 @@ def reset(self): with open(os.path.join(self.working_path, "data", "myid"), "w") as fh: fh.write(str(self.server_info.server_id)) - def stop(self): + def stop(self) -> None: """Stop the Zookeeper instance, retaining on disk state.""" if not self.running: return @@ -335,14 +338,14 @@ def stop(self): ) self._running = False - def destroy(self): + def destroy(self) -> None: """Stop the ZooKeeper instance and destroy its on disk-state""" # called by at exit handler, reimport to avoid cleanup race. self.stop() shutil.rmtree(self.working_path, True) - def get_logs(self, num_lines=100): + def get_logs(self, num_lines: int = 100) -> list[str]: log_path = pathlib.Path(self.working_path, "zookeeper.log") if log_path.exists(): log_file = log_path.open("r") @@ -354,24 +357,24 @@ def get_logs(self, num_lines=100): class ZookeeperCluster(object): def __init__( self, - install_path=None, - classpath=None, - size=3, - port_offset=20000, - observer_start_id=-1, - configuration_entries=(), - java_system_properties=(), - jaas_config=None, + install_path: str, + classpath: str, + size: int, + port_offset: int, + observer_start_id: int, + configuration_entries: list[str], + java_system_properties: list[str], + jaas_config: str | None, ): self._install_path = install_path self._classpath = classpath self._servers = [] - self._ssl_configuration = {} + self._ssl_configuration: dict[str, Any] = {} self.perform_ssl_certs_generation() # Calculate ports and peer group port = port_offset - peers = [] + peers: list[ServerInfo] = [] for i in range(size): server_id = i + 1 @@ -408,13 +411,13 @@ def __init__( ) ) - def __getitem__(self, k): + def __getitem__(self, k: int) -> ManagedZooKeeper: return self._servers[k] - def __iter__(self): + def __iter__(self) -> Iterator[ManagedZooKeeper]: return iter(self._servers) - def start(self): + def start(self) -> None: # Zookeeper client expresses a preference for either lower ports or # lexicographical ordering of hosts, to ensure that all servers have a # chance to startup, start them in reverse order. @@ -427,26 +430,26 @@ def start(self): time.sleep(2) - def stop(self): + def stop(self) -> None: for server in self: server.stop() self._servers = [] - def terminate(self): + def terminate(self) -> None: for server in self: server.destroy() - def reset(self): + def reset(self) -> None: for server in self: server.reset() - def get_logs(self): + def get_logs(self) -> list[str]: logs = [] for server in self: logs += server.get_logs() return logs - def perform_ssl_certs_generation(self): + def perform_ssl_certs_generation(self) -> None: if self._ssl_configuration: return @@ -542,7 +545,7 @@ def perform_ssl_certs_generation(self): "keystore": keystore, } - def get_ssl_client_configuration(self): + def get_ssl_client_configuration(self) -> dict[str, Any]: if not self._ssl_configuration: raise RuntimeError("SSL not configured yet.") return { diff --git a/kazoo/testing/harness.py b/kazoo/testing/harness.py index bb77f0717..10542c99a 100644 --- a/kazoo/testing/harness.py +++ b/kazoo/testing/harness.py @@ -1,10 +1,17 @@ """Kazoo testing harnesses""" +from __future__ import annotations + import atexit import logging import os import uuid import unittest +from typing import Any, Callable, Literal, cast, TYPE_CHECKING + +if TYPE_CHECKING: + import kazoo.interfaces + from kazoo.client import KazooClient from kazoo.exceptions import KazooException from kazoo.protocol.connection import _CONNECTION_DROP, _SESSION_EXPIRED @@ -13,8 +20,8 @@ log = logging.getLogger(__name__) -CLUSTER = None -CLUSTER_CONF = None +CLUSTER: ZookeeperCluster | None = None +CLUSTER_CONF: dict[str, Any] | None = None CLUSTER_DEFAULTS = { "ZOOKEEPER_PORT_OFFSET": 20000, "ZOOKEEPER_CLUSTER_SIZE": 3, @@ -24,7 +31,8 @@ MAX_INIT_TRIES = 5 -def get_global_cluster(): +# FIXME use a typeddict for cluster conf and cluster defaults +def get_global_cluster() -> ZookeeperCluster: global CLUSTER, CLUSTER_CONF cluster_conf = { k: os.environ.get(k, CLUSTER_DEFAULTS.get(k)) @@ -47,16 +55,22 @@ def get_global_cluster(): CLUSTER.terminate() CLUSTER = None # Create a new cluster - ZK_HOME = cluster_conf.get("ZOOKEEPER_PATH") - ZK_CLASSPATH = cluster_conf.get("ZOOKEEPER_CLASSPATH") - ZK_PORT_OFFSET = int(cluster_conf.get("ZOOKEEPER_PORT_OFFSET")) - ZK_CLUSTER_SIZE = int(cluster_conf.get("ZOOKEEPER_CLUSTER_SIZE")) - ZK_VERSION = cluster_conf.get("ZOOKEEPER_VERSION") - if "-" in ZK_VERSION: + ZK_HOME = cast(str, cluster_conf.get("ZOOKEEPER_PATH")) + ZK_CLASSPATH = cast(str, cluster_conf.get("ZOOKEEPER_CLASSPATH")) + ZK_PORT_OFFSET = int( # type: ignore[call-overload] + cluster_conf.get("ZOOKEEPER_PORT_OFFSET") + ) + ZK_CLUSTER_SIZE = int( # type: ignore[call-overload] + cluster_conf.get("ZOOKEEPER_CLUSTER_SIZE") + ) + ZK_VERSION_STR = cast(str, cluster_conf.get("ZOOKEEPER_VERSION")) + if "-" in ZK_VERSION_STR: # Ignore pre-release markers like -alpha - ZK_VERSION = ZK_VERSION.split("-")[0] - ZK_VERSION = tuple([int(n) for n in ZK_VERSION.split(".")]) - ZK_OBSERVER_START_ID = int(cluster_conf.get("ZOOKEEPER_OBSERVER_START_ID")) + ZK_VERSION_STR = ZK_VERSION_STR.split("-")[0] + ZK_VERSION = tuple([int(n) for n in ZK_VERSION_STR.split(".")]) + ZK_OBSERVER_START_ID = int( # type: ignore[call-overload] + cluster_conf.get("ZOOKEEPER_OBSERVER_START_ID") + ) assert ZK_HOME or ZK_CLASSPATH or ZK_VERSION, ( "Either ZOOKEEPER_PATH or ZOOKEEPER_CLASSPATH or " @@ -65,8 +79,8 @@ def get_global_cluster(): ) if ZK_VERSION >= (3, 5): - ZOOKEEPER_LOCAL_SESSION_RO = cluster_conf.get( - "ZOOKEEPER_LOCAL_SESSION_RO" + ZOOKEEPER_LOCAL_SESSION_RO = cast( + str, cluster_conf.get("ZOOKEEPER_LOCAL_SESSION_RO") ) additional_configuration_entries = [ "4lw.commands.whitelist=*", @@ -140,69 +154,78 @@ class KazooTestHarness(unittest.TestCase): Example:: class MyTestCase(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: self.setup_zookeeper() # additional test setup - def tearDown(self): + def tearDown(self)-> None: self.teardown_zookeeper() - def test_something(self): + def test_something(self) -> None: something_that_needs_a_kazoo_client(self.client) - def test_something_else(self): + def test_something_else(self) -> None: something_that_needs_zk_servers(self.servers) """ DEFAULT_CLIENT_TIMEOUT = 15 - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any): super(KazooTestHarness, self).__init__(*args, **kw) - self.client = None - self._clients = [] + self._client: KazooClient | None = None + self._clients: list[KazooClient] = [] @property - def cluster(self): + def cluster(self) -> ZookeeperCluster: return get_global_cluster() - def log(self, level, msg, *args, **kwargs): + @property + def client(self) -> KazooClient: + assert self._client is not None + return self._client + + def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None: log.log(level, msg, *args, **kwargs) @property - def servers(self): + def servers(self) -> str: return ",".join([s.address for s in self.cluster]) @property - def secure_servers(self): + def secure_servers(self) -> str: return ",".join([s.secure_address for s in self.cluster]) - def _get_nonchroot_client(self): + def _get_nonchroot_client(self) -> KazooClient: c = KazooClient(self.servers) self._clients.append(c) return c - def _get_client(self, **client_options): + def _get_client(self, **client_options: Any) -> KazooClient: if "timeout" not in client_options: client_options["timeout"] = self.DEFAULT_CLIENT_TIMEOUT c = KazooClient(self.hosts, **client_options) self._clients.append(c) return c - def lose_connection(self, event_factory): + def lose_connection( + self, event_factory: Callable[[], kazoo.interfaces.Event] + ) -> None: """Force client to lose connection with server""" self.__break_connection( _CONNECTION_DROP, KazooState.SUSPENDED, event_factory ) - def expire_session(self, event_factory): + def expire_session( + self, event_factory: Callable[[], kazoo.interfaces.Event] + ) -> None: """Force ZK to expire a client session""" self.__break_connection( _SESSION_EXPIRED, KazooState.LOST, event_factory ) - def setup_zookeeper(self, **client_options): + def setup_zookeeper(self, **client_options: Any) -> None: """Create a ZK cluster and chrooted :class:`KazooClient` The cluster will only be created on the first invocation and won't be @@ -242,11 +265,11 @@ def setup_zookeeper(self, **client_options): self.hosts = self.secure_servers + namespace else: self.hosts = self.servers + namespace - self.client = self._get_client(**client_options) + self._client = self._get_client(**client_options) self.client.start() self.client.ensure_path("/") - def teardown_zookeeper(self): + def teardown_zookeeper(self) -> None: """Reset and cleanup the zookeeper cluster that was started.""" while self._clients: c = self._clients.pop() @@ -256,23 +279,29 @@ def teardown_zookeeper(self): log.exception("Failed stopping client %s", c) finally: c.close() - self.client = None - - def __break_connection(self, break_event, expected_state, event_factory): + self._client = None + + def __break_connection( + self, + break_event: object, + expected_state: KazooState, + event_factory: Callable[[], kazoo.interfaces.Event], + ) -> None: """Break ZooKeeper connection using the specified event.""" lost = event_factory() safe = event_factory() - def watch_loss(state): + def watch_loss(state: KazooState) -> Literal[True] | None: if state == expected_state: lost.set() elif lost.is_set() and state == KazooState.CONNECTED: safe.set() return True + return None self.client.add_listener(watch_loss) - self.client._call(break_event, None) + self.client._call(break_event, None) # type: ignore[arg-type] lost.wait(5) if not lost.is_set(): @@ -286,14 +315,14 @@ def watch_loss(state): class KazooTestCase(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: self.setup_zookeeper() - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: cluster = get_global_cluster() if cluster is not None: cluster.terminate() diff --git a/kazoo/tests/conftest.py b/kazoo/tests/conftest.py index 931fd84f0..94da635aa 100644 --- a/kazoo/tests/conftest.py +++ b/kazoo/tests/conftest.py @@ -1,9 +1,11 @@ import logging +from typing import Any + log = logging.getLogger(__name__) -def pytest_exception_interact(node, call, report): +def pytest_exception_interact(node: Any, call: Any, report: Any) -> None: try: cluster = node._testcase.cluster log.error("Zookeeper cluster logs:") diff --git a/kazoo/tests/test_barrier.py b/kazoo/tests/test_barrier.py index 5f79e861d..f77c31696 100644 --- a/kazoo/tests/test_barrier.py +++ b/kazoo/tests/test_barrier.py @@ -1,35 +1,37 @@ +from __future__ import annotations + import threading from kazoo.testing import KazooTestCase class KazooBarrierTests(KazooTestCase): - def test_barrier_not_exist(self): + def test_barrier_not_exist(self) -> None: b = self.client.Barrier("/some/path") assert b.wait() - def test_barrier_exists(self): + def test_barrier_exists(self) -> None: b = self.client.Barrier("/some/path") b.create() assert not b.wait(0) b.remove() assert b.wait() - def test_remove_nonexistent_barrier(self): + def test_remove_nonexistent_barrier(self) -> None: b = self.client.Barrier("/some/path") assert not b.remove() class KazooDoubleBarrierTests(KazooTestCase): - def test_basic_barrier(self): + def test_basic_barrier(self) -> None: b = self.client.DoubleBarrier("/some/path", 1) assert not b.participating b.enter() assert b.participating - b.leave() + b.leave() # type: ignore[unreachable] assert not b.participating - def test_two_barrier(self): + def test_two_barrier(self) -> None: av = threading.Event() ev = threading.Event() bv = threading.Event() @@ -37,14 +39,14 @@ def test_two_barrier(self): b1 = self.client.DoubleBarrier("/some/path", 2) b2 = self.client.DoubleBarrier("/some/path", 2) - def make_barrier_one(): + def make_barrier_one() -> None: b1.enter() ev.set() release_all.wait() b1.leave() ev.set() - def make_barrier_two(): + def make_barrier_two() -> None: bv.wait() b2.enter() av.set() @@ -65,7 +67,7 @@ def make_barrier_two(): av.wait() ev.wait() assert b1.participating - assert b2.participating + assert b2.participating # type: ignore[unreachable] av.clear() ev.clear() @@ -78,7 +80,7 @@ def make_barrier_two(): t1.join() t2.join() - def test_three_barrier(self): + def test_three_barrier(self) -> None: av = threading.Event() ev = threading.Event() bv = threading.Event() @@ -87,14 +89,14 @@ def test_three_barrier(self): b2 = self.client.DoubleBarrier("/some/path", 3) b3 = self.client.DoubleBarrier("/some/path", 3) - def make_barrier_one(): + def make_barrier_one() -> None: b1.enter() ev.set() release_all.wait() b1.leave() ev.set() - def make_barrier_two(): + def make_barrier_two() -> None: bv.wait() b2.enter() av.set() @@ -119,7 +121,7 @@ def make_barrier_two(): av.wait() assert b1.participating - assert b2.participating + assert b2.participating # type: ignore[unreachable] assert b3.participating av.clear() @@ -135,7 +137,7 @@ def make_barrier_two(): t1.join() t2.join() - def test_barrier_existing_parent_node(self): + def test_barrier_existing_parent_node(self) -> None: b = self.client.DoubleBarrier("/some/path", 1) assert b.participating is False self.client.create("/some", ephemeral=True) @@ -143,7 +145,7 @@ def test_barrier_existing_parent_node(self): b.enter() assert b.participating is False - def test_barrier_existing_node(self): + def test_barrier_existing_node(self) -> None: b = self.client.DoubleBarrier("/some", 1) assert b.participating is False self.client.ensure_path(b.path) @@ -151,4 +153,4 @@ def test_barrier_existing_node(self): # the barrier will re-use an existing node b.enter() assert b.participating is True - b.leave() + b.leave() # type: ignore[unreachable] diff --git a/kazoo/tests/test_build.py b/kazoo/tests/test_build.py index 01dbf8737..7b36ddf13 100644 --- a/kazoo/tests/test_build.py +++ b/kazoo/tests/test_build.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import pytest @@ -6,14 +8,14 @@ class TestBuildEnvironment(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) if not os.environ.get("CI"): pytest.skip("Only run build config tests on CI.") - def test_zookeeper_version(self): - server_version = self.client.server_version() - server_version = ".".join([str(i) for i in server_version]) + def test_zookeeper_version(self) -> None: + server_version1 = self.client.server_version() + server_version = ".".join([str(i) for i in server_version1]) env_version = os.environ.get("ZOOKEEPER_VERSION") if env_version: if "-" in env_version: diff --git a/kazoo/tests/test_cache.py b/kazoo/tests/test_cache.py index 7251db643..a5bf2d787 100644 --- a/kazoo/tests/test_cache.py +++ b/kazoo/tests/test_cache.py @@ -1,7 +1,11 @@ +from __future__ import annotations + import gc import importlib import sys import uuid +from typing import Any, TYPE_CHECKING +from queue import Queue from unittest.mock import patch, call, Mock import pytest @@ -11,6 +15,11 @@ from kazoo.exceptions import KazooException from kazoo.recipe.cache import TreeCache, TreeNode, TreeEvent +if TYPE_CHECKING: + from kazoo.handlers.gevent import SequentialGeventHandler + from kazoo.handlers.eventlet import SequentialEventletHandler + from kazoo.handlers.threading import SequentialThreadingHandler + class KazooAdaptiveHandlerTestCase(KazooTestHarness): HANDLERS = ( @@ -19,15 +28,22 @@ class KazooAdaptiveHandlerTestCase(KazooTestHarness): ("kazoo.handlers.threading", "SequentialThreadingHandler"), ) - def setUp(self): + def setUp(self) -> None: self.handler = self.choose_an_installed_handler() self.setup_zookeeper(handler=self.handler) - def tearDown(self): + def tearDown(self) -> None: self.handler = None self.teardown_zookeeper() - def choose_an_installed_handler(self): + def choose_an_installed_handler( + self, + ) -> ( + SequentialGeventHandler + | SequentialEventletHandler + | SequentialThreadingHandler + | None + ): for handler_module, handler_class in self.HANDLERS: if ( handler_module == "kazoo.handlers.gevent" @@ -40,39 +56,60 @@ def choose_an_installed_handler(self): except ImportError: continue else: - return cls() + # FIXME Should be no-any-return but hound is a dog + return cls() # type: ignore raise ImportError("No available handler") class KazooTreeCacheTests(KazooAdaptiveHandlerTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooTreeCacheTests, self).setUp() - self._event_queue = self.client.handler.queue_impl() + self._event_queue: Queue[TreeEvent] = self.client.handler.queue_impl() self._error_queue = self.client.handler.queue_impl() - self.path = None - self.cache = None + self._path: str | None = None + self._cache: TreeCache | None = None - def tearDown(self): + def tearDown(self) -> None: if not self._error_queue.empty(): try: raise self._error_queue.get() except FakeException: pass - if self.cache is not None: - self.cache.close() - self.cache = None + if self._cache is not None: + self._cache.close() + self._cache = None super(KazooTreeCacheTests, self).tearDown() - def make_cache(self): - if self.cache is None: - self.path = "/" + uuid.uuid4().hex - self.cache = TreeCache(self.client, self.path) - self.cache.listen(lambda event: self._event_queue.put(event)) - self.cache.listen_fault(lambda error: self._error_queue.put(error)) - self.cache.start() - return self.cache - - def wait_cache(self, expect=None, since=None, timeout=10): + def make_cache(self) -> TreeCache: + if self._cache is None: + self._path = "/" + uuid.uuid4().hex + self._cache = TreeCache(self.client, self.path) + self._cache.listen(lambda event: self._event_queue.put(event)) + self._cache.listen_fault( + lambda error: self._error_queue.put(error) + ) + self._cache.start() + return self._cache + + # FIXME This is entirely for the purpose of minimising code changes. + # Calling make_cache twice should be an error and the return value + # should be used, not stored. + @property + def cache(self) -> TreeCache: + assert self._cache is not None + return self._cache + + @property + def path(self) -> str: + assert self._path is not None + return self._path + + def wait_cache( + self, + expect: int | None = None, + since: int | None = None, + timeout: float = 10, + ) -> TreeEvent | None: started = since is None while True: event = self._event_queue.get(timeout=timeout) @@ -83,25 +120,25 @@ def wait_cache(self, expect=None, since=None, timeout=10): if event.event_type == since: started = True if expect is None: - return + return None - def spy_client(self, method_name): + def spy_client(self, method_name: str) -> Any: method = getattr(self.client, method_name) return patch.object(self.client, method_name, wraps=method) - def _wait_gc(self): + def _wait_gc(self) -> None: # trigger switching on some coroutine handlers self.client.handler.sleep_func(0.1) completion_queue = getattr(self.handler, "completion_queue", None) if completion_queue is not None: - while not self.client.handler.completion_queue.empty(): + while not completion_queue.empty(): self.client.handler.sleep_func(0.1) for gen in range(3): gc.collect(gen) - def count_tree_node(self): + def count_tree_node(self) -> int: # inspect GC and count tree nodes for checking memory leak for retry in range(10): result = set() @@ -112,28 +149,29 @@ def count_tree_node(self): return list(result)[0] raise RuntimeError("could not count refs exactly") - def test_start(self): + def test_start(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) stat = self.client.exists(self.path) + assert stat is not None assert stat.version == 0 assert self.cache._state == TreeCache.STATE_STARTED assert self.cache._root._state == TreeNode.STATE_LIVE - def test_start_started(self): + def test_start_started(self) -> None: self.make_cache() with pytest.raises(KazooException): self.cache.start() - def test_start_closed(self): + def test_start_closed(self) -> None: self.make_cache() self.cache.close() with pytest.raises(KazooException): self.cache.start() - def test_close(self): + def test_close(self) -> None: assert self.count_tree_node() == 0 self.make_cache() @@ -192,11 +230,13 @@ def test_close(self): == stub_child_watcher ) + # FIXME This looks pointless at best. + self._cache = None + # should not be any leaked memory (tree node) here - self.cache = None assert self.count_tree_node() == 0 - def test_delete_operation(self): + def test_delete_operation(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) @@ -225,12 +265,13 @@ def test_delete_operation(self): # should not be any leaked memory (tree node) here assert self.count_tree_node() == 1 - def test_children_operation(self): + def test_children_operation(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/test_children", b"test_children_1") event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_type == TreeEvent.NODE_ADDED assert event.event_data.path == self.path + "/test_children" assert event.event_data.data == b"test_children_1" @@ -238,6 +279,7 @@ def test_children_operation(self): self.client.set(self.path + "/test_children", b"test_children_2") event = self.wait_cache(TreeEvent.NODE_UPDATED) + assert event is not None assert event.event_type == TreeEvent.NODE_UPDATED assert event.event_data.path == self.path + "/test_children" assert event.event_data.data == b"test_children_2" @@ -245,18 +287,20 @@ def test_children_operation(self): self.client.delete(self.path + "/test_children") event = self.wait_cache(TreeEvent.NODE_REMOVED) + assert event is not None assert event.event_type == TreeEvent.NODE_REMOVED assert event.event_data.path == self.path + "/test_children" assert event.event_data.data == b"test_children_2" assert event.event_data.stat.version == 1 - def test_subtree_operation(self): + def test_subtree_operation(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo/bar/baz", makepath=True) for relative_path in ("/foo", "/foo/bar", "/foo/bar/baz"): event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_type == TreeEvent.NODE_ADDED assert event.event_data.path == self.path + relative_path assert event.event_data.data == b"" @@ -265,10 +309,11 @@ def test_subtree_operation(self): self.client.delete(self.path + "/foo", recursive=True) for relative_path in ("/foo/bar/baz", "/foo/bar", "/foo"): event = self.wait_cache(TreeEvent.NODE_REMOVED) + assert event is not None assert event.event_type == TreeEvent.NODE_REMOVED assert event.event_data.path == self.path + relative_path - def test_get_data(self): + def test_get_data(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo/bar/baz", b"@", makepath=True) @@ -277,19 +322,27 @@ def test_get_data(self): self.wait_cache(TreeEvent.NODE_ADDED) with patch.object(cache, "_client"): # disable any remote operation - assert cache.get_data(self.path).data == b"" - assert cache.get_data(self.path).stat.version == 0 - - assert cache.get_data(self.path + "/foo").data == b"" - assert cache.get_data(self.path + "/foo").stat.version == 0 - - assert cache.get_data(self.path + "/foo/bar").data == b"" - assert cache.get_data(self.path + "/foo/bar").stat.version == 0 - - assert cache.get_data(self.path + "/foo/bar/baz").data == b"@" - assert cache.get_data(self.path + "/foo/bar/baz").stat.version == 0 - - def test_get_children(self): + node = cache.get_data(self.path) + assert node is not None + assert node.data == b"" + assert node.stat.version == 0 + + node = cache.get_data(self.path + "foo") + assert node is not None + assert node.data == b"" + assert node.stat.version == 0 + + node = cache.get_data(self.path + "foo/bar") + assert node is not None + assert node.data == b"" + assert node.stat.version == 0 + + node = cache.get_data(self.path + "foo/bar/baz") + assert node is not None + assert node.data == b"@" + assert node.stat.version == 0 + + def test_get_children(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo/bar/baz", b"@", makepath=True) @@ -307,38 +360,39 @@ def test_get_children(self): assert cache.get_children(self.path + "/foo") == frozenset(["bar"]) assert cache.get_children(self.path) == frozenset(["foo"]) - def test_get_data_out_of_tree(self): + def test_get_data_out_of_tree(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with pytest.raises(ValueError): self.cache.get_data("/out_of_tree") - def test_get_children_out_of_tree(self): + def test_get_children_out_of_tree(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with pytest.raises(ValueError): self.cache.get_children("/out_of_tree") - def test_get_data_no_node(self): + def test_get_data_no_node(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with patch.object(cache, "_client"): # disable any remote operation assert cache.get_data(self.path + "/non_exists") is None - def test_get_children_no_node(self): + def test_get_children_no_node(self) -> None: cache = self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) with patch.object(cache, "_client"): # disable any remote operation assert cache.get_children(self.path + "/non_exists") is None - def test_session_reconnected(self): + def test_session_reconnected(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) self.client.create(self.path + "/foo") event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_data.path == self.path + "/foo" with self.spy_client("get_async") as get_data: @@ -382,13 +436,14 @@ def test_session_reconnected(self): any_order=True, ) - def test_root_recreated(self): + def test_root_recreated(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) # remove root node self.client.delete(self.path) event = self.wait_cache(TreeEvent.NODE_REMOVED) + assert event is not None assert event.event_type == TreeEvent.NODE_REMOVED assert event.event_data.data == b"" assert event.event_data.path == self.path @@ -397,6 +452,7 @@ def test_root_recreated(self): # re-create root node self.client.ensure_path(self.path) event = self.wait_cache(TreeEvent.NODE_ADDED) + assert event is not None assert event.event_type == TreeEvent.NODE_ADDED assert event.event_data.data == b"" assert event.event_data.path == self.path @@ -406,7 +462,7 @@ def test_root_recreated(self): "unexpected outstanding ops %r" % self.cache._outstanding_ops ) - def test_exception_handler(self): + def test_exception_handler(self) -> None: error_value = FakeException() error_handler = Mock() @@ -419,7 +475,7 @@ def test_exception_handler(self): self.cache.close() error_handler.assert_called_once_with(error_value) - def test_exception_suppressed(self): + def test_exception_suppressed(self) -> None: self.make_cache() self.wait_cache(since=TreeEvent.INITIALIZED) diff --git a/kazoo/tests/test_client.py b/kazoo/tests/test_client.py index 3f1748c4d..bc7a78e01 100644 --- a/kazoo/tests/test_client.py +++ b/kazoo/tests/test_client.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import socket import tempfile @@ -5,11 +7,14 @@ import time import uuid import unittest + +from typing import Any from unittest.mock import Mock, MagicMock, patch import pytest -from kazoo.testing import KazooTestCase + +from kazoo.client import KazooClient from kazoo.exceptions import ( AuthFailedError, BadArgumentsError, @@ -19,27 +24,45 @@ ConnectionLoss, InvalidACLError, NoAuthError, + NoChildrenForEphemeralsError, NoNodeError, NodeExistsError, + RolledBackError, SessionExpiredError, KazooException, ) +from kazoo.interfaces import IAsyncResult +from kazoo.protocol.states import KazooState, KeeperState, WatchedEvent +from kazoo.handlers.threading import ( + SequentialThreadingHandler, + KazooTimeoutError, +) from kazoo.protocol.connection import _CONNECTION_DROP -from kazoo.protocol.states import KeeperState, KazooState +from kazoo.retry import KazooRetry +from kazoo.security import ( + make_digest_acl_credential, + CREATOR_ALL_ACL, + make_digest_acl, + ACL, + OPEN_ACL_UNSAFE, +) + +from kazoo.testing import KazooTestCase +from kazoo.testing.common import ManagedZooKeeper from kazoo.tests.util import CI_ZK_VERSION class TestClientTransitions(KazooTestCase): @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def test_connection_and_disconnection(self): + def test_connection_and_disconnection(self) -> None: states = [] rc = threading.Event() @self.client.add_listener - def listener(state): + def listener(state: KazooState) -> None: states.append(state) if state == KazooState.CONNECTED: rc.set() @@ -61,18 +84,14 @@ def listener(state): class TestClientConstructor(unittest.TestCase): - def _makeOne(self, *args, **kw): - from kazoo.client import KazooClient - + def _makeOne(self, *args: Any, **kw: Any) -> KazooClient: return KazooClient(*args, **kw) - def test_invalid_handler(self): - from kazoo.handlers.threading import SequentialThreadingHandler - + def test_invalid_handler(self) -> None: with pytest.raises(ConfigurationError): self._makeOne(handler=SequentialThreadingHandler) - def test_chroot(self): + def test_chroot(self) -> None: assert self._makeOne(hosts="127.0.0.1:2181/").chroot == "" assert self._makeOne(hosts="127.0.0.1:2181/a").chroot == "/a" assert self._makeOne(hosts="127.0.0.1/a").chroot == "/a" @@ -82,35 +101,31 @@ def test_chroot(self): == "/a/b" ) - def test_connection_timeout(self): - from kazoo.handlers.threading import KazooTimeoutError - + def test_connection_timeout(self) -> None: client = self._makeOne(hosts="127.0.0.1:9") assert client.handler.timeout_exception is KazooTimeoutError with pytest.raises(KazooTimeoutError): client.start(0.1) - def test_ordered_host_selection(self): + def test_ordered_host_selection(self) -> None: client = self._makeOne( hosts="127.0.0.1:9,127.0.0.2:9/a", randomize_hosts=False ) hosts = [h for h in client.hosts] assert hosts == [("127.0.0.1", 9), ("127.0.0.2", 9)] - def test_invalid_hostname(self): + def test_invalid_hostname(self) -> None: client = self._makeOne(hosts="nosuchhost/a") timeout = client.handler.timeout_exception with pytest.raises(timeout): client.start(0.1) - def test_another_invalid_hostname(self): + def test_another_invalid_hostname(self) -> None: with pytest.raises(ValueError): self._makeOne(hosts="/nosuchhost/a") - def test_retry_options_dict(self): - from kazoo.retry import KazooRetry - + def test_retry_options_dict(self) -> None: client = self._makeOne( command_retry=dict(max_tries=99), connection_retry=dict(delay=99) ) @@ -121,12 +136,10 @@ def test_retry_options_dict(self): class TestAuthentication(KazooTestCase): - def _makeAuth(self, *args, **kwargs): - from kazoo.security import make_digest_acl - + def _makeAuth(self, *args: Any, **kwargs: Any) -> ACL: return make_digest_acl(*args, **kwargs) - def test_auth(self): + def test_auth(self) -> None: username = uuid.uuid4().hex password = uuid.uuid4().hex @@ -162,7 +175,8 @@ def test_auth(self): eve.stop() eve.close() - def test_connect_auth(self): + def test_connect_auth(self) -> None: + username = uuid.uuid4().hex password = uuid.uuid4().hex @@ -184,7 +198,7 @@ def test_connect_auth(self): client.stop() client.close() - def test_unicode_auth(self): + def test_unicode_auth(self) -> None: username = r"xe4/\hm" password = r"/\xe4hm" digest_auth = "%s:%s" % (username, password) @@ -217,17 +231,19 @@ def test_unicode_auth(self): eve.stop() eve.close() - def test_invalid_auth(self): + def test_invalid_auth(self) -> None: client = self._get_client() client.start() with pytest.raises(TypeError): - client.add_auth("digest", ("user", "pass")) + client.add_auth( + "digest", ("user", "pass") # type: ignore[arg-type] + ) with pytest.raises(TypeError): - client.add_auth(None, ("user", "pass")) + client.add_auth(None, ("user", "pass")) # type: ignore[arg-type] - def test_async_auth(self): + def test_async_auth(self) -> None: client = self._get_client() client.start() username = uuid.uuid4().hex @@ -236,7 +252,7 @@ def test_async_auth(self): result = client.add_auth_async("digest", digest_auth) assert result.get() is True - def test_async_auth_failure(self): + def test_async_auth_failure(self) -> None: client = self._get_client() client.start() username = uuid.uuid4().hex @@ -246,10 +262,11 @@ def test_async_auth_failure(self): with pytest.raises(AuthFailedError): client.add_auth("unknown-scheme", digest_auth) - def test_add_auth_on_reconnect(self): + def test_add_auth_on_reconnect(self) -> None: client = self._get_client() client.start() client.add_auth("digest", "jsmith:jsmith") + assert client._connection._socket is not None client._connection._socket.shutdown(socket.SHUT_RDWR) while not client.connected: time.sleep(0.1) @@ -258,14 +275,14 @@ def test_add_auth_on_reconnect(self): class TestConnection(KazooTestCase): @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() @staticmethod - def make_condition(): + def make_condition() -> threading.Condition: return threading.Condition() - def test_chroot_warning(self): + def test_chroot_warning(self) -> None: k = self._get_nonchroot_client() k.chroot = "abba" try: @@ -275,12 +292,11 @@ def test_chroot_warning(self): finally: k.stop() - def test_session_expire(self): - from kazoo.protocol.states import KazooState + def test_session_expire(self) -> None: cv = self.make_event() - def watch_events(event): + def watch_events(event: KazooState) -> None: if event == KazooState.LOST: cv.set() @@ -289,17 +305,15 @@ def watch_events(event): cv.wait(3) assert cv.is_set() - def test_bad_session_expire(self): - from kazoo.protocol.states import KazooState + def test_bad_session_expire(self) -> None: cv = self.make_event() ab = self.make_event() - def watch_events(event): + def watch_events(event: KazooState) -> None: if event == KazooState.LOST: ab.set() raise Exception("oops") - cv.set() self.client.add_listener(watch_events) self.expire_session(self.make_event) @@ -308,13 +322,12 @@ def watch_events(event): cv.wait(0.5) assert not cv.is_set() - def test_state_listener(self): - from kazoo.protocol.states import KazooState + def test_state_listener(self) -> None: states = [] condition = self.make_condition() - def listener(state): + def listener(state: KazooState) -> None: with condition: states.append(state) condition.notify_all() @@ -331,18 +344,17 @@ def listener(state): assert len(states) == 1 assert states[0] == KazooState.CONNECTED - def test_invalid_listener(self): + def test_invalid_listener(self) -> None: with pytest.raises(ConfigurationError): - self.client.add_listener(15) + self.client.add_listener(15) # type: ignore[arg-type] - def test_listener_only_called_on_real_state_change(self): - from kazoo.protocol.states import KazooState + def test_listener_only_called_on_real_state_change(self) -> None: assert self.client.state == KazooState.CONNECTED called = [False] condition = self.make_event() - def listener(state): + def listener(state: KazooState) -> None: called[0] = True condition.set() @@ -351,7 +363,7 @@ def listener(state): condition.wait(3) assert called[0] is False - def test_no_connection(self): + def test_no_connection(self) -> None: client = self.client client.stop() assert client.connected is False @@ -360,12 +372,12 @@ def test_no_connection(self): with pytest.raises(ConnectionClosedError): client.exists("/") - def test_close_connecting_connection(self): + def test_close_connecting_connection(self) -> None: client = self.client client.stop() ev = self.make_event() - def close_on_connecting(state): + def close_on_connecting(state: KazooState) -> None: if state in (KazooState.CONNECTED, KazooState.LOST): ev.set() @@ -385,23 +397,23 @@ def close_on_connecting(state): with pytest.raises(ConnectionClosedError): self.client.create("/foobar") - def test_double_start(self): + def test_double_start(self) -> None: assert self.client.connected is True self.client.start() assert self.client.connected is True - def test_double_stop(self): + def test_double_stop(self) -> None: self.client.stop() assert self.client.connected is False self.client.stop() assert self.client.connected is False - def test_restart(self): + def test_restart(self) -> None: assert self.client.connected is True self.client.restart() assert self.client.connected is True - def test_closed(self): + def test_closed(self) -> None: client = self.client client.stop() @@ -433,13 +445,13 @@ def test_closed(self): client._state = oldstate client._connection._write_sock = None - def test_watch_trigger_expire(self): + def test_watch_trigger_expire(self) -> None: client = self.client cv = self.make_event() client.create("/test", b"") - def test_watch(event): + def test_watch(event: WatchedEvent) -> None: cv.set() client.get("/test/", watch=test_watch) @@ -450,17 +462,11 @@ def test_watch(event): class TestClient(KazooTestCase): - def _makeOne(self, *args): - from kazoo.handlers.threading import SequentialThreadingHandler - + def _makeOne(self, *args: Any) -> SequentialThreadingHandler: return SequentialThreadingHandler(*args) - def _getKazooState(self): - from kazoo.protocol.states import KazooState + def test_server_version_retries_fail(self) -> None: - return KazooState - - def test_server_version_retries_fail(self): client = self.client side_effects = [ "", @@ -468,38 +474,38 @@ def test_server_version_retries_fail(self): "zookeeper.version=1.", "zookeeper.ver", ] - client.command = MagicMock() + client.command = MagicMock() # type: ignore[method-assign] client.command.side_effect = side_effects with pytest.raises(KazooException): client.server_version(retries=len(side_effects) - 1) - def test_server_version_retries_eventually_ok(self): + def test_server_version_retries_eventually_ok(self) -> None: client = self.client actual_version = "zookeeper.version=1.2" side_effects = [] for i in range(0, len(actual_version) + 1): side_effects.append(actual_version[0:i]) - client.command = MagicMock() + client.command = MagicMock() # type: ignore[method-assign] client.command.side_effect = side_effects assert client.server_version(retries=len(side_effects) - 1) == (1, 2) - def test_client_id(self): + def test_client_id(self) -> None: client_id = self.client.client_id assert type(client_id) is tuple # make sure password is of correct length assert len(client_id[1]) == 16 - def test_connected(self): + def test_connected(self) -> None: client = self.client assert client.connected - def test_create(self): + def test_create(self) -> None: client = self.client path = client.create("/1") assert path == "/1" assert client.exists("/1") - def test_create_on_broken_connection(self): + def test_create_on_broken_connection(self) -> None: client = self.client client.start() @@ -517,34 +523,34 @@ def test_create_on_broken_connection(self): with pytest.raises(ConnectionClosedError): client.create("/closedpath", b"bar") - def test_create_null_data(self): + def test_create_null_data(self) -> None: client = self.client client.create("/nulldata", None) value, _ = client.get("/nulldata") assert value is None - def test_create_empty_string(self): + def test_create_empty_string(self) -> None: client = self.client client.create("/empty", b"") value, _ = client.get("/empty") assert value == b"" - def test_create_unicode_path(self): + def test_create_unicode_path(self) -> None: client = self.client path = client.create("/ascii") assert path == "/ascii" path = client.create("/\xe4hm") assert path == "/\xe4hm" - def test_create_async_returns_unchrooted_path(self): + def test_create_async_returns_unchrooted_path(self) -> None: client = self.client path = client.create_async("/1").get() assert path == "/1" - def test_create_invalid_path(self): + def test_create_invalid_path(self) -> None: client = self.client with pytest.raises(TypeError): - client.create(("a",)) + client.create(("a",)) # type:ignore[call-overload] with pytest.raises(ValueError): client.create(".") with pytest.raises(ValueError): @@ -554,36 +560,34 @@ def test_create_invalid_path(self): with pytest.raises(BadArgumentsError): client.create("/b\x1e") - def test_create_invalid_arguments(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_create_invalid_arguments(self) -> None: single_acl = OPEN_ACL_UNSAFE[0] client = self.client with pytest.raises(TypeError): - client.create("a", acl="all") + client.create("a", acl="all") # type: ignore[arg-type] with pytest.raises(TypeError): - client.create("a", acl=single_acl) + client.create("a", acl=single_acl) # type: ignore[arg-type] with pytest.raises(TypeError): - client.create("a", value=["a"]) + client.create("a", value=["a"]) # type: ignore[call-overload] with pytest.raises(TypeError): - client.create("a", ephemeral="yes") + client.create("a", ephemeral="yes") # type: ignore[call-overload] with pytest.raises(TypeError): - client.create("a", sequence="yes") + client.create("a", sequence="yes") # type: ignore[call-overload] with pytest.raises(TypeError): - client.create("a", makepath="yes") + client.create("a", makepath="yes") # type: ignore[call-overload] - def test_create_value(self): + def test_create_value(self) -> None: client = self.client client.create("/1", b"bytes") data, stat = client.get("/1") assert data == b"bytes" - def test_create_unicode_value(self): + def test_create_unicode_value(self) -> None: client = self.client with pytest.raises(TypeError): - client.create("/1", "\xe4hm") + client.create("/1", "\xe4hm") # type: ignore[call-overload] - def test_create_large_value(self): + def test_create_large_value(self) -> None: client = self.client kb_512 = b"a" * (512 * 1024) client.create("/1", kb_512) @@ -592,49 +596,41 @@ def test_create_large_value(self): with pytest.raises(ConnectionLoss): client.create("/2", mb_2) - def test_create_acl_duplicate(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_create_acl_duplicate(self) -> None: single_acl = OPEN_ACL_UNSAFE[0] client = self.client client.create("/1", acl=[single_acl, single_acl]) acls, stat = client.get_acls("/1") # ZK >3.4 removes duplicate ACL entries - if CI_ZK_VERSION: - version = CI_ZK_VERSION - else: - version = client.server_version() + version = CI_ZK_VERSION if CI_ZK_VERSION else client.server_version() assert len(acls) == 1 if version > (3, 4) else 2 - def test_create_acl_empty_list(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_create_acl_empty_list(self) -> None: client = self.client client.create("/1", acl=[]) acls, stat = client.get_acls("/1") assert acls == OPEN_ACL_UNSAFE - def test_version_no_connection(self): + def test_version_no_connection(self) -> None: self.client.stop() with pytest.raises(ConnectionLoss): self.client.server_version() - def test_create_ephemeral(self): + def test_create_ephemeral(self) -> None: client = self.client client.create("/1", b"ephemeral", ephemeral=True) + assert client.client_id is not None data, stat = client.get("/1") assert data == b"ephemeral" assert stat.ephemeralOwner == client.client_id[0] - def test_create_no_ephemeral(self): + def test_create_no_ephemeral(self) -> None: client = self.client client.create("/1", b"val1") data, stat = client.get("/1") assert not stat.ephemeralOwner - def test_create_ephemeral_no_children(self): - from kazoo.exceptions import NoChildrenForEphemeralsError - + def test_create_ephemeral_no_children(self) -> None: client = self.client client.create("/1", b"ephemeral", ephemeral=True) with pytest.raises(NoChildrenForEphemeralsError): @@ -642,7 +638,7 @@ def test_create_ephemeral_no_children(self): with pytest.raises(NoChildrenForEphemeralsError): client.create("/1/2", b"val1", ephemeral=True) - def test_create_sequence(self): + def test_create_sequence(self) -> None: client = self.client client.create("/folder") path = client.create("/folder/a", b"sequence", sequence=True) @@ -652,7 +648,7 @@ def test_create_sequence(self): path3 = client.create("/folder/", b"sequence", sequence=True) assert path3 == "/folder/0000000002" - def test_create_ephemeral_sequence(self): + def test_create_ephemeral_sequence(self) -> None: basepath = "/" + uuid.uuid4().hex realpath = self.client.create( basepath, b"sandwich", sequence=True, ephemeral=True @@ -661,7 +657,7 @@ def test_create_ephemeral_sequence(self): data, stat = self.client.get(realpath) assert data == b"sandwich" - def test_create_makepath(self): + def test_create_makepath(self) -> None: self.client.create("/1/2", b"val1", makepath=True) data, stat = self.client.get("/1/2") assert data == b"val1" @@ -673,10 +669,7 @@ def test_create_makepath(self): with pytest.raises(NodeExistsError): self.client.create("/1/2/3/4/5", b"val2", makepath=True) - def test_create_makepath_incompatible_acls(self): - from kazoo.client import KazooClient - from kazoo.security import make_digest_acl_credential, CREATOR_ALL_ACL - + def test_create_makepath_incompatible_acls(self) -> None: credential = make_digest_acl_credential("username", "password") alt_client = KazooClient( self.cluster[0].address + self.client.chroot, @@ -695,7 +688,7 @@ def test_create_makepath_incompatible_acls(self): alt_client.delete("/", recursive=True) alt_client.stop() - def test_create_no_makepath(self): + def test_create_no_makepath(self) -> None: with pytest.raises(NoNodeError): self.client.create("/1/2", b"val1") with pytest.raises(NoNodeError): @@ -705,15 +698,13 @@ def test_create_no_makepath(self): with pytest.raises(NoNodeError): self.client.create("/1/2/3/4", b"val1", makepath=False) - def test_create_exists(self): - from kazoo.exceptions import NodeExistsError - + def test_create_exists(self) -> None: client = self.client path = client.create("/1") with pytest.raises(NodeExistsError): client.create(path) - def test_create_stat(self): + def test_create_stat(self) -> None: if CI_ZK_VERSION: version = CI_ZK_VERSION else: @@ -726,7 +717,7 @@ def test_create_stat(self): assert data == b"bytes" assert stat1 == stat2 - def test_create_get_set(self): + def test_create_get_set(self) -> None: nodepath = "/" + uuid.uuid4().hex self.client.create(nodepath, b"sandwich", ephemeral=True) @@ -749,20 +740,20 @@ def test_create_get_set(self): assert newstat.children_count == stat.numChildren assert newstat.children_version == stat.cversion - def test_get_invalid_arguments(self): + def test_get_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.get(("a", "b")) + client.get(("a", "b")) # type: ignore[arg-type] with pytest.raises(TypeError): - client.get("a", watch=True) + client.get("a", watch=True) # type: ignore[arg-type] - def test_bad_argument(self): + def test_bad_argument(self) -> None: client = self.client client.ensure_path("/1") with pytest.raises(TypeError): - self.client.set("/1", 1) + self.client.set("/1", 1) # type: ignore[arg-type] - def test_ensure_path(self): + def test_ensure_path(self) -> None: client = self.client client.ensure_path("/1/2") assert client.exists("/1/2") @@ -770,13 +761,13 @@ def test_ensure_path(self): client.ensure_path("/1/2/3/4") assert client.exists("/1/2/3/4") - def test_sync(self): + def test_sync(self) -> None: client = self.client assert client.sync("/") == "/" # Albeit surprising, you can sync anything, even what does not exist. assert client.sync("/not_there") == "/not_there" - def test_exists(self): + def test_exists(self) -> None: nodepath = "/" + uuid.uuid4().hex exists = self.client.exists(nodepath) @@ -791,18 +782,18 @@ def test_exists(self): exists = self.client.exists(multi_node_nonexistent) assert exists is None - def test_exists_invalid_arguments(self): + def test_exists_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.exists(("a", "b")) + client.exists(("a", "b")) # type: ignore[arg-type] with pytest.raises(TypeError): - client.exists("a", watch=True) + client.exists("a", watch=True) # type: ignore[arg-type] - def test_exists_watch(self): + def test_exists_watch(self) -> None: nodepath = "/" + uuid.uuid4().hex event = self.client.handler.event_object() - def w(watch_event): + def w(watch_event: WatchedEvent) -> None: assert watch_event.path == nodepath event.set() @@ -814,12 +805,12 @@ def w(watch_event): event.wait(1) assert event.is_set() is True - def test_exists_watcher_exception(self): + def test_exists_watcher_exception(self) -> None: nodepath = "/" + uuid.uuid4().hex event = self.client.handler.event_object() # if the watcher throws an exception, all we can really do is log it - def w(watch_event): + def w(watch_event: WatchedEvent) -> None: assert watch_event.path == nodepath event.set() @@ -833,7 +824,7 @@ def w(watch_event): event.wait(1) assert event.is_set() is True - def test_create_delete(self): + def test_create_delete(self) -> None: nodepath = "/" + uuid.uuid4().hex self.client.create(nodepath, b"zzz") @@ -843,9 +834,7 @@ def test_create_delete(self): exists = self.client.exists(nodepath) assert exists is None - def test_get_acls(self): - from kazoo.security import make_digest_acl - + def test_get_acls(self) -> None: user = "user" passw = "pass" acl = make_digest_acl(user, passw, all=True) @@ -857,14 +846,12 @@ def test_get_acls(self): finally: client.delete("/a") - def test_get_acls_invalid_arguments(self): + def test_get_acls_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.get_acls(("a", "b")) - - def test_set_acls(self): - from kazoo.security import make_digest_acl + client.get_acls(("a", "b")) # type: ignore[arg-type] + def test_set_acls(self) -> None: user = "user" passw = "pass" acl = make_digest_acl(user, passw, all=True) @@ -877,34 +864,30 @@ def test_set_acls(self): finally: client.delete("/a") - def test_set_acls_empty(self): + def test_set_acls_empty(self) -> None: client = self.client client.create("/a") with pytest.raises(InvalidACLError): client.set_acls("/a", []) - def test_set_acls_no_node(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_set_acls_no_node(self) -> None: client = self.client with pytest.raises(NoNodeError): client.set_acls("/a", OPEN_ACL_UNSAFE) - def test_set_acls_invalid_arguments(self): - from kazoo.security import OPEN_ACL_UNSAFE - + def test_set_acls_invalid_arguments(self) -> None: single_acl = OPEN_ACL_UNSAFE[0] client = self.client with pytest.raises(TypeError): - client.set_acls(("a", "b"), ()) + client.set_acls(("a", "b"), ()) # type: ignore[arg-type] with pytest.raises(TypeError): - client.set_acls("a", single_acl) + client.set_acls("a", single_acl) # type: ignore[arg-type] with pytest.raises(TypeError): - client.set_acls("a", "all") + client.set_acls("a", "all") # type: ignore[arg-type] with pytest.raises(TypeError): - client.set_acls("a", [single_acl], "V1") + client.set_acls("a", [single_acl], "V1") # type: ignore[arg-type] - def test_set(self): + def test_set(self) -> None: client = self.client client.create("a", b"first") stat = client.set("a", b"second") @@ -912,38 +895,38 @@ def test_set(self): assert data == b"second" assert stat == stat2 - def test_set_null_data(self): + def test_set_null_data(self) -> None: client = self.client client.create("/nulldata", b"not none") client.set("/nulldata", None) value, _ = client.get("/nulldata") assert value is None - def test_set_empty_string(self): + def test_set_empty_string(self) -> None: client = self.client client.create("/empty", b"not empty") client.set("/empty", b"") value, _ = client.get("/empty") assert value == b"" - def test_set_invalid_arguments(self): + def test_set_invalid_arguments(self) -> None: client = self.client client.create("a", b"first") with pytest.raises(TypeError): - client.set(("a", "b"), b"value") + client.set(("a", "b"), b"value") # type: ignore[arg-type] with pytest.raises(TypeError): - client.set("a", ["v", "w"]) + client.set("a", ["v", "w"]) # type: ignore[arg-type] with pytest.raises(TypeError): - client.set("a", b"value", "V1") + client.set("a", b"value", "V1") # type: ignore[arg-type] - def test_delete(self): + def test_delete(self) -> None: client = self.client client.ensure_path("/a/b") assert "b" in client.get_children("a") client.delete("/a/b") assert "b" not in client.get_children("a") - def test_delete_recursive(self): + def test_delete_recursive(self) -> None: client = self.client client.ensure_path("/a/b/c") client.ensure_path("/a/b/d") @@ -951,17 +934,17 @@ def test_delete_recursive(self): client.delete("/a/b/c", recursive=True) assert "b" not in client.get_children("a") - def test_delete_invalid_arguments(self): + def test_delete_invalid_arguments(self) -> None: client = self.client client.ensure_path("/a/b") with pytest.raises(TypeError): - client.delete("/a/b", recursive="all") + client.delete("/a/b", recursive="all") # type: ignore[arg-type] with pytest.raises(TypeError): - client.delete(("a", "b")) + client.delete(("a", "b")) # type: ignore[arg-type] with pytest.raises(TypeError): - client.delete("/a/b", version="V1") + client.delete("/a/b", version="V1") # type: ignore[arg-type] - def test_get_children(self): + def test_get_children(self) -> None: client = self.client client.ensure_path("/a/b/c") client.ensure_path("/a/b/d") @@ -969,7 +952,7 @@ def test_get_children(self): assert set(client.get_children("/a/b")) == set(["c", "d"]) assert client.get_children("/a/b/c") == [] - def test_get_children2(self): + def test_get_children2(self) -> None: client = self.client client.ensure_path("/a/b") children, stat = client.get_children("/a", include_data=True) @@ -977,7 +960,7 @@ def test_get_children2(self): assert children == ["b"] assert stat2.version == stat.version - def test_get_children2_many_nodes(self): + def test_get_children2_many_nodes(self) -> None: client = self.client client.ensure_path("/a/b") client.ensure_path("/a/c") @@ -987,31 +970,30 @@ def test_get_children2_many_nodes(self): assert set(children) == set(["b", "c", "d"]) assert stat2.version == stat.version - def test_get_children_no_node(self): + def test_get_children_no_node(self) -> None: client = self.client with pytest.raises(NoNodeError): client.get_children("/none") with pytest.raises(NoNodeError): client.get_children("/none", include_data=True) - def test_get_children_invalid_path(self): + def test_get_children_invalid_path(self) -> None: client = self.client with pytest.raises(ValueError): client.get_children("../a") - def test_get_children_invalid_arguments(self): + def test_get_children_invalid_arguments(self) -> None: client = self.client with pytest.raises(TypeError): - client.get_children(("a", "b")) + client.get_children(("a", "b")) # type: ignore[call-overload] with pytest.raises(TypeError): - client.get_children("a", watch=True) + client.get_children("a", watch=True) # type: ignore[call-overload] with pytest.raises(TypeError): - client.get_children("a", include_data="yes") - - def test_invalid_auth(self): - from kazoo.exceptions import AuthFailedError - from kazoo.protocol.states import KeeperState + client.get_children( # type: ignore[call-overload] + "a", include_data="yes" + ) + def test_invalid_auth(self) -> None: client = self.client client.stop() client._state = KeeperState.AUTH_FAILED @@ -1019,15 +1001,10 @@ def test_invalid_auth(self): with pytest.raises(AuthFailedError): client.get("/") - def test_client_state(self): - from kazoo.protocol.states import KeeperState - + def test_client_state(self) -> None: assert self.client.client_state == KeeperState.CONNECTED - def test_update_host_list(self): - from kazoo.client import KazooClient - from kazoo.protocol.states import KeeperState - + def test_update_host_list(self) -> None: hosts = self.cluster[0].address # create a client with only one server in its list client = KazooClient(hosts=hosts) @@ -1049,8 +1026,9 @@ def test_update_host_list(self): self.cluster[0].run() # utility for test_request_queuing* - def _make_request_queuing_client(self): - from kazoo.client import KazooClient + def _make_request_queuing_client( + self, + ) -> tuple[KazooClient, ManagedZooKeeper]: server = self.cluster[0] handler = self._makeOne() @@ -1071,11 +1049,17 @@ def _make_request_queuing_client(self): return client, server # utility for test_request_queuing* - def _request_queuing_common(self, client, server, path, expire_session): + def _request_queuing_common( + self, + client: KazooClient, + server: ManagedZooKeeper, + path: str, + expire_session: bool, + ) -> IAsyncResult: ev_suspended = client.handler.event_object() ev_connected = client.handler.event_object() - def listener(state): + def listener(state: KazooState) -> None: if state == KazooState.SUSPENDED: ev_suspended.set() elif state == KazooState.CONNECTED: @@ -1117,7 +1101,7 @@ def listener(state): return result - def test_request_queuing_session_recovered(self): + def test_request_queuing_session_recovered(self) -> None: path = "/" + uuid.uuid4().hex client, server = self._make_request_queuing_client() @@ -1131,7 +1115,7 @@ def test_request_queuing_session_recovered(self): finally: client.stop() - def test_request_queuing_session_expired(self): + def test_request_queuing_session_expired(self) -> None: path = "/" + uuid.uuid4().hex client, server = self._make_request_queuing_client() @@ -1148,7 +1132,7 @@ def test_request_queuing_session_expired(self): class TestSSLClient(KazooTestCase): - def setUp(self): + def setUp(self) -> None: if CI_ZK_VERSION and CI_ZK_VERSION < (3, 5): pytest.skip("Must use Zookeeper 3.5 or above") ssl_path = tempfile.mkdtemp() @@ -1171,7 +1155,7 @@ def setUp(self): use_ssl=True, keyfile=key_path, certfile=cert_path, ca=cacert_path ) - def test_create(self): + def test_create(self) -> None: client = self.client path = client.create("/1") assert path == "/1" @@ -1194,7 +1178,7 @@ def test_create(self): class TestClientTransactions(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) skip = False if CI_ZK_VERSION and CI_ZK_VERSION < (3, 4): @@ -1208,7 +1192,7 @@ def setUp(self): if skip: pytest.skip("Must use Zookeeper 3.4 or above") - def test_basic_create(self): + def test_basic_create(self) -> None: t = self.client.transaction() t.create("/freddy") t.create("/fred", ephemeral=True) @@ -1218,7 +1202,7 @@ def test_basic_create(self): assert results[0] == "/freddy" assert results[2].startswith("/smith0") is True - def test_bad_creates(self): + def test_bad_creates(self) -> None: args_list = [ (True,), ("/smith", 0), @@ -1230,11 +1214,9 @@ def test_bad_creates(self): for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.create(*args) - - def test_default_acl(self): - from kazoo.security import make_digest_acl + t.create(*args) # type: ignore[arg-type] + def test_default_acl(self) -> None: username = uuid.uuid4().hex password = uuid.uuid4().hex @@ -1249,14 +1231,14 @@ def test_default_acl(self): results = t.commit() assert results[0] == "/freddy" - def test_basic_delete(self): + def test_basic_delete(self) -> None: self.client.create("/fred") t = self.client.transaction() t.delete("/fred") results = t.commit() assert results[0] is True - def test_bad_deletes(self): + def test_bad_deletes(self) -> None: args_list = [ (True,), ("/smith", "woops"), @@ -1265,9 +1247,9 @@ def test_bad_deletes(self): for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.delete(*args) + t.delete(*args) # type: ignore[arg-type] - def test_set(self): + def test_set(self) -> None: self.client.create("/fred", b"01") t = self.client.transaction() t.set_data("/fred", b"oops") @@ -1275,15 +1257,15 @@ def test_set(self): res = self.client.get("/fred") assert res[0] == b"oops" - def test_bad_sets(self): + def test_bad_sets(self) -> None: args_list = [(42, 52), ("/smith", False), ("/smith", b"", "oops")] for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.set_data(*args) + t.set_data(*args) # type: ignore[arg-type] - def test_check(self): + def test_check(self) -> None: self.client.create("/fred") version = self.client.get("/fred")[1].version t = self.client.transaction() @@ -1293,17 +1275,15 @@ def test_check(self): assert results[0] is True assert results[1] == "/blah" - def test_bad_checks(self): + def test_bad_checks(self) -> None: args_list = [(42, 52), ("/smith", "oops")] for args in args_list: with pytest.raises(TypeError): t = self.client.transaction() - t.check(*args) - - def test_bad_transaction(self): - from kazoo.exceptions import RolledBackError, NoNodeError + t.check(*args) # type: ignore[arg-type] + def test_bad_transaction(self) -> None: t = self.client.transaction() t.create("/fred") t.delete("/smith") @@ -1311,57 +1291,56 @@ def test_bad_transaction(self): assert results[0].__class__ == RolledBackError assert results[1].__class__ == NoNodeError - def test_bad_commit(self): + def test_bad_commit(self) -> None: t = self.client.transaction() t.committed = True with pytest.raises(ValueError): t.commit() - def test_bad_context(self): + def test_bad_context(self) -> None: with pytest.raises(TypeError): with self.client.transaction() as t: - t.check(4232) + t.check(4232) # type: ignore[arg-type,call-arg] - def test_context(self): + def test_context(self) -> None: with self.client.transaction() as t: t.create("/smith", b"32") assert self.client.get("/smith")[0] == b"32" class TestSessionCallbacks(unittest.TestCase): - def test_session_callback_states(self): - from kazoo.protocol.states import KazooState, KeeperState - from kazoo.client import KazooClient - + def test_session_callback_states(self) -> None: client = KazooClient() - client._handle = 1 client._live.set() - result = client._session_callback(KeeperState.CONNECTED) - assert result is None + client._session_callback(KeeperState.CONNECTED) # Now with stopped client._stopped.set() - result = client._session_callback(KeeperState.CONNECTED) - assert result is None + client._session_callback(KeeperState.CONNECTED) # Test several state transitions client._stopped.clear() - client.start_async = lambda: True + client.start_async = ( # type: ignore[method-assign] + lambda: threading.Event() + ) client._session_callback(KeeperState.CONNECTED) assert client.state == KazooState.CONNECTED client._session_callback(KeeperState.AUTH_FAILED) - assert client.state == KazooState.LOST + # FIXME mypy seems to be under the impression that the state can't + # change as a result of the above call, even though it can. + assert ( + client.state == KazooState.LOST # type: ignore[comparison-overlap] + ) - client._handle = 1 - client._session_callback(-250) + client._session_callback(-250) # type: ignore[unreachable] assert client.state == KazooState.SUSPENDED class TestCallbacks(KazooTestCase): - def test_async_result_callbacks_are_always_called(self): + def test_async_result_callbacks_are_always_called(self) -> None: # create a callback object callback_mock = Mock() @@ -1384,7 +1363,7 @@ def test_async_result_callbacks_are_always_called(self): class TestNonChrootClient(KazooTestCase): - def test_create(self): + def test_create(self) -> None: client = self._get_nonchroot_client() assert client.chroot == "" client.start() @@ -1393,7 +1372,7 @@ def test_create(self): client.delete(path) client.stop() - def test_unchroot(self): + def test_unchroot(self) -> None: client = self._get_nonchroot_client() client.chroot = "/a" # Unchroot'ing the chroot path should return "/" @@ -1403,7 +1382,7 @@ def test_unchroot(self): class TestReconfig(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) if CI_ZK_VERSION: @@ -1413,7 +1392,7 @@ def setUp(self): if not version or version < (3, 5): pytest.skip("Must use Zookeeper 3.5 or above") - def test_no_super_auth(self): + def test_no_super_auth(self) -> None: with pytest.raises(NoAuthError): self.client.reconfig( joining="server.999=0.0.0.0:1234:2345:observer;3456", @@ -1421,8 +1400,8 @@ def test_no_super_auth(self): new_members=None, ) - def test_add_remove_observer(self): - def free_sock_port(): + def test_add_remove_observer(self) -> None: + def free_sock_port() -> tuple[socket.socket, int]: s = socket.socket() s.bind(("", 0)) return s, s.getsockname()[1] @@ -1466,7 +1445,7 @@ def free_sock_port(): from_config=curver + 1, ) - def test_bad_input(self): + def test_bad_input(self) -> None: with pytest.raises(BadArgumentsError): self.client.reconfig( joining="some thing", leaving=None, new_members=None diff --git a/kazoo/tests/test_connection.py b/kazoo/tests/test_connection.py index 032b94bbe..0579a91bd 100644 --- a/kazoo/tests/test_connection.py +++ b/kazoo/tests/test_connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import namedtuple, deque import os import threading @@ -7,15 +9,18 @@ import struct import sys +from typing import Any, Iterable, Deque, Tuple import pytest -from kazoo.exceptions import ConnectionLoss +from kazoo.client import KazooClient +from kazoo.exceptions import ConnectionLoss, NotReadOnlyCallError +from kazoo.interfaces import FdLike from kazoo.protocol.serialization import ( Connect, int_struct, write_string, ) -from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KazooState, KeeperState from kazoo.protocol.connection import _CONNECTION_DROP from kazoo.testing import KazooTestCase from kazoo.tests.util import wait, CI_ZK_VERSION, CI @@ -24,32 +29,33 @@ class Delete(namedtuple("Delete", "path version")): type = 2 - def serialize(self): + def serialize(self) -> bytearray: b = bytearray() b.extend(write_string(self.path)) b.extend(int_struct.pack(self.version)) return b @classmethod - def deserialize(self, bytes, offset): + def deserialize(self, bytes: bytes, offset: int) -> None: raise ValueError("oh my") class TestConnectionHandler(KazooTestCase): - def test_bad_deserialization(self): + def test_bad_deserialization(self) -> None: async_object = self.client.handler.async_result() self.client._queue.append( (Delete(self.client.chroot, -1), async_object) ) + assert self.client._connection._write_sock is not None self.client._connection._write_sock.send(b"\0") with pytest.raises(ValueError): async_object.get() - def test_with_bad_sessionid(self): + def test_with_bad_sessionid(self) -> None: ev = threading.Event() - def expired(state): + def expired(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() @@ -63,7 +69,7 @@ def expired(state): finally: client.stop() - def test_connection_read_timeout(self): + def test_connection_read_timeout(self) -> None: client = self.client ev = threading.Event() path = "/" + uuid.uuid4().hex @@ -71,31 +77,33 @@ def test_connection_read_timeout(self): _select = handler.select _socket = client._connection._socket - def delayed_select(*args, **kwargs): + def delayed_select( + *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: result = _select(*args, **kwargs) if len(args[0]) == 1 and _socket in args[0]: # for any socket read, simulate a timeout return [], [], [] return result - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() client.add_listener(back) client.create(path, b"1") try: - handler.select = delayed_select + handler.select = delayed_select # type: ignore[method-assign] with pytest.raises(ConnectionLoss): client.get(path) finally: - handler.select = _select + handler.select = _select # type: ignore[method-assign] # the client reconnects automatically ev.wait(5) assert ev.is_set() assert client.get(path)[0] == b"1" - def test_connection_write_timeout(self): + def test_connection_write_timeout(self) -> None: client = self.client ev = threading.Event() path = "/" + uuid.uuid4().hex @@ -103,31 +111,33 @@ def test_connection_write_timeout(self): _select = handler.select _socket = client._connection._socket - def delayed_select(*args, **kwargs): + def delayed_select( + *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: result = _select(*args, **kwargs) if _socket in args[1]: # for any socket write, simulate a timeout return [], [], [] return result - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() client.add_listener(back) try: - handler.select = delayed_select + handler.select = delayed_select # type: ignore[method-assign] with pytest.raises(ConnectionLoss): client.create(path) finally: - handler.select = _select + handler.select = _select # type: ignore[method-assign] # the client reconnects automatically ev.wait(5) assert ev.is_set() assert client.exists(path) is None - def test_connection_deserialize_fail(self): + def test_connection_deserialize_fail(self) -> None: client = self.client ev = threading.Event() path = "/" + uuid.uuid4().hex @@ -135,14 +145,16 @@ def test_connection_deserialize_fail(self): _select = handler.select _socket = client._connection._socket - def delayed_select(*args, **kwargs): + def delayed_select( + *args: Any, **kwargs: Any + ) -> tuple[Iterable[FdLike], Iterable[FdLike], Iterable[FdLike]]: result = _select(*args, **kwargs) if _socket in args[1]: # for any socket write, simulate a timeout return [], [], [] return result - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() @@ -150,7 +162,7 @@ def back(state): deserialize_ev = threading.Event() - def bad_deserialize(_bytes, offset): + def bad_deserialize(_bytes: bytes, offset: int) -> None: deserialize_ev.set() raise struct.error() @@ -162,11 +174,11 @@ def bad_deserialize(_bytes, offset): with patch.object(Connect, "deserialize") as mock_deserialize: mock_deserialize.side_effect = bad_deserialize try: - handler.select = delayed_select + handler.select = delayed_select # type: ignore[method-assign] with pytest.raises(ConnectionLoss): client.create(path) finally: - handler.select = _select + handler.select = _select # type: ignore[method-assign] # the client reconnects automatically but the first attempt will # hit a deserialize failure. wait for that. deserialize_ev.wait(5) @@ -177,7 +189,7 @@ def bad_deserialize(_bytes, offset): assert ev.is_set() assert client.exists(path) is None - def test_connection_close(self): + def test_connection_close(self) -> None: with pytest.raises(Exception): self.client.close() self.client.stop() @@ -186,7 +198,7 @@ def test_connection_close(self): # should be able to restart self.client.start() - def test_connection_sock(self): + def test_connection_sock(self) -> None: client = self.client read_sock = client._connection._read_sock write_sock = client._connection._write_sock @@ -217,10 +229,12 @@ def test_connection_sock(self): read_sock.getsockname() write_sock.getsockname() - def test_dirty_sock(self): + def test_dirty_sock(self) -> None: client = self.client read_sock = client._connection._read_sock write_sock = client._connection._write_sock + assert read_sock is not None + assert write_sock is not None # add a stray byte to the socket and ensure that doesn't # blow up client. simulates case where some error leaves @@ -233,10 +247,10 @@ def test_dirty_sock(self): class TestConnectionDrop(KazooTestCase): - def test_connection_dropped(self): + def test_connection_dropped(self) -> None: ev = threading.Event() - def back(state): + def back(state: KazooState) -> None: if state == KazooState.CONNECTED: ev.set() @@ -245,7 +259,7 @@ def back(state): self.client.create(path) self.client.add_listener(back) result = self.client.set_async(path, b"a" * 1000 * 1024) - self.client._call(_CONNECTION_DROP, None) + self.client._call(_CONNECTION_DROP, None) # type: ignore[arg-type] with pytest.raises(ConnectionLoss): result.get() @@ -255,7 +269,7 @@ def back(state): class TestReadOnlyMode(KazooTestCase): - def setUp(self): + def setUp(self) -> None: os.environ["ZOOKEEPER_LOCAL_SESSION_RO"] = "true" self.setup_zookeeper() skip = False @@ -270,31 +284,28 @@ def setUp(self): if skip: pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.client.stop() os.environ.pop("ZOOKEEPER_LOCAL_SESSION_RO", None) - def test_read_only(self): - from kazoo.exceptions import NotReadOnlyCallError - from kazoo.protocol.states import KeeperState - + def test_read_only(self) -> None: if CI: # force some wait to make sure the data produced during the - # `setUp()` step are replicaed to all zk members + # `setUp()` step are replicated to all zk members # if not done the `get_children()` test may fail because the # node does not exist on the node that we will keep alive time.sleep(15) # do not keep the client started in the `setUp` step alive self.client.stop() client = self._get_client(connection_retry=None, read_only=True) - states = [] ev = threading.Event() - @client.add_listener - def listen(state): - states.append(state) + def listen(state: KazooState) -> bool | None: if client.client_state == KeeperState.CONNECTED_RO: ev.set() + return None + + client.add_listener(listen) client.start() try: @@ -336,7 +347,7 @@ def listen(state): class TestUnorderedXids(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(TestUnorderedXids, self).setUp() self.connection = self.client._connection @@ -345,46 +356,53 @@ def setUp(self): self._pending = self.client._pending self.client._pending = _naughty_deque() - def tearDown(self): + def tearDown(self) -> None: self.client._pending = self._pending super(TestUnorderedXids, self).tearDown() - def _get_client(self, **kwargs): + def _get_client(self, **kwargs: Any) -> KazooClient: # overrides for patching zk_loop c = KazooTestCase._get_client(self, **kwargs) self._zk_loop = c._connection.zk_loop - self._zk_loop_errors = [] - c._connection.zk_loop = self._zk_loop_func + self._zk_loop_errors: list[BaseException] = [] + c._connection.zk_loop = ( # type: ignore[method-assign] + self._zk_loop_func + ) return c - def _zk_loop_func(self, *args, **kwargs): + def _zk_loop_func(self, *args: Any, **kwargs: Any) -> None: # patched zk_loop which will catch and collect all RuntimeError try: self._zk_loop(*args, **kwargs) except RuntimeError as e: self._zk_loop_errors.append(e) - def test_xids_mismatch(self): + def test_xids_mismatch(self) -> None: from kazoo.protocol.states import KeeperState ev = threading.Event() error_stack = [] - @self.client.add_listener - def listen(state): + def listen(state: KazooState) -> bool | None: if self.client.client_state == KeeperState.CLOSED: ev.set() + return None - def log_exception(*args): + self.client.add_listener(listen) + + def log_exception(*args: Any) -> None: error_stack.append((args, sys.exc_info())) - self.connection.logger.exception = log_exception + self.connection.logger.exception = ( # type: ignore[method-assign] + log_exception # type: ignore[assignment] + ) ev.clear() with pytest.raises(RuntimeError): self.client.get_children("/") ev.wait() + self.client.remove_listener(listen) assert self.client.connected is False assert self.client.state == "LOST" assert self.client.client_state == KeeperState.CLOSED @@ -394,12 +412,13 @@ def log_exception(*args): assert exc_info[0] == RuntimeError self.client.handler.sleep_func(0.2) + assert self.connection_routine is not None assert not self.connection_routine.is_alive() assert len(self._zk_loop_errors) == 1 assert self._zk_loop_errors[0] == exc_info[1] -class _naughty_deque(deque): - def append(self, s): +class _naughty_deque(Deque[Tuple[Any, Any, int]]): + def append(self, s: Tuple[Any, Any, int]) -> None: request, async_object, xid = s - return deque.append(self, (request, async_object, xid + 1)) # +1s + deque.append(self, (request, async_object, xid + 1)) # +1s diff --git a/kazoo/tests/test_counter.py b/kazoo/tests/test_counter.py index a78667356..aff4012ba 100644 --- a/kazoo/tests/test_counter.py +++ b/kazoo/tests/test_counter.py @@ -1,16 +1,21 @@ +from __future__ import annotations + import uuid +from typing import Any + import pytest +from kazoo.recipe.counter import Counter from kazoo.testing import KazooTestCase class KazooCounterTests(KazooTestCase): - def _makeOne(self, **kw): + def _makeOne(self, **kw: Any) -> Counter: path = "/" + uuid.uuid4().hex return self.client.Counter(path, **kw) - def test_int_counter(self): + def test_int_counter(self) -> None: counter = self._makeOne() assert counter.value == 0 counter += 2 @@ -20,7 +25,7 @@ def test_int_counter(self): counter - 1 assert counter.value == -1 - def test_int_curator_counter(self): + def test_int_curator_counter(self) -> None: counter = self._makeOne(support_curator=True) assert counter.value == 0 counter += 2 @@ -36,7 +41,7 @@ def test_int_curator_counter(self): counter -= 2147483647 assert counter.value == -2147483647 - def test_float_counter(self): + def test_float_counter(self) -> None: counter = self._makeOne(default=0.0) assert counter.value == 0.0 counter += 2.1 @@ -44,16 +49,18 @@ def test_float_counter(self): counter -= 3.1 assert counter.value == -1.0 - def test_errors(self): + def test_errors(self) -> None: counter = self._makeOne() with pytest.raises(TypeError): - counter.__add__(2.1) + counter.__add__(2.1) # type: ignore[arg-type] with pytest.raises(TypeError): - counter.__add__(b"a") + counter.__add__(b"a") # type: ignore[operator] with pytest.raises(TypeError): - counter = self._makeOne(default=0.0, support_curator=True) + counter = self._makeOne( # type: ignore[arg-type] + default=0.0, support_curator=True + ) - def test_pre_post_values(self): + def test_pre_post_values(self) -> None: counter = self._makeOne() assert counter.value == 0 assert counter.pre_value is None diff --git a/kazoo/tests/test_election.py b/kazoo/tests/test_election.py index c7b1b5959..d2cb1e460 100644 --- a/kazoo/tests/test_election.py +++ b/kazoo/tests/test_election.py @@ -1,19 +1,27 @@ +from __future__ import annotations + import uuid import sys import threading +from typing import TYPE_CHECKING, cast import pytest +from kazoo.recipe.election import Election from kazoo.testing import KazooTestCase from kazoo.tests.util import wait +if TYPE_CHECKING: + from types import TracebackType + from _typeshed import OptExcInfo + class UniqueError(Exception): """Error raised only by test leader function""" class KazooElectionTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooElectionTests, self).setUp() self.path = "/" + uuid.uuid4().hex @@ -21,17 +29,19 @@ def setUp(self): # election contenders set these when elected. The exit event is set by # the test to make the leader exit. - self.leader_id = None - self.exit_event = None + self.leader_id: str | None = None + self.exit_event: threading.Event | None = None # tests set this before the event to make the leader raise an error self.raise_exception = False # set by a worker thread when an unexpected error is hit. # better way to do this? - self.thread_exc_info = None + self.thread_exc_info: OptExcInfo | None = None - def _spawn_contender(self, contender_id, election): + def _spawn_contender( + self, contender_id: str, election: Election + ) -> threading.Thread: thread = threading.Thread( target=self._election_thread, args=(contender_id, election) ) @@ -39,7 +49,7 @@ def _spawn_contender(self, contender_id, election): thread.start() return thread - def _election_thread(self, contender_id, election): + def _election_thread(self, contender_id: str, election: Election) -> None: try: election.run(self._leader_func, contender_id) except UniqueError: @@ -50,9 +60,13 @@ def _election_thread(self, contender_id, election): else: if self.raise_exception: e = Exception("expected leader func to raise exception") - self.thread_exc_info = (Exception, e, None) + self.thread_exc_info = ( + Exception, + e, + cast("TracebackType", None), + ) - def _leader_func(self, name): + def _leader_func(self, name: str) -> None: exit_event = threading.Event() with self.condition: self.exit_event = exit_event @@ -63,12 +77,13 @@ def _leader_func(self, name): if self.raise_exception: raise UniqueError("expected error in the leader function") - def _check_thread_error(self): - if self.thread_exc_info: + def _check_thread_error(self) -> None: + if self.thread_exc_info is not None: t, o, tb = self.thread_exc_info + assert t is not None raise t(o) - def test_election(self): + def test_election(self) -> None: elections = {} threads = {} for _ in range(3): @@ -105,6 +120,7 @@ def test_election(self): elections[contenders[1]].cancel() # make leader exit. third contender should be elected. + assert self.exit_event is not None self.exit_event.set() with self.condition: while self.leader_id == first_leader: @@ -137,7 +153,8 @@ def test_election(self): thread.join() self._check_thread_error() - def test_bad_func(self): + def test_bad_func(self) -> None: election = self.client.Election(self.path) + # FIXME If we're using type hints, we don't need to check this. with pytest.raises(ValueError): - election.run("not a callable") + election.run("not a callable") # type: ignore[arg-type] diff --git a/kazoo/tests/test_eventlet_handler.py b/kazoo/tests/test_eventlet_handler.py index ff2649d98..08db67684 100644 --- a/kazoo/tests/test_eventlet_handler.py +++ b/kazoo/tests/test_eventlet_handler.py @@ -1,27 +1,28 @@ +from __future__ import annotations + import contextlib import unittest +from typing import Generator, Literal + import pytest -from kazoo.client import KazooClient +from kazoo.handlers.utils import create_tcp_socket from kazoo.handlers import utils from kazoo.protocol import states as kazoo_states -from kazoo.tests import test_client -from kazoo.tests import test_lock -from kazoo.tests import util as test_util try: - import eventlet - from eventlet.green import threading + from eventlet.green import socket from kazoo.handlers import eventlet as eventlet_handler - - EVENTLET_HANDLER_AVAILABLE = True + from kazoo.handlers.eventlet import SequentialEventletHandler except ImportError: - EVENTLET_HANDLER_AVAILABLE = False + pytestmark = pytest.mark.skip(reason="eventlet not available") @contextlib.contextmanager -def start_stop_one(handler=None): +def start_stop_one( + handler: SequentialEventletHandler = None, # type: ignore[assignment] +) -> Generator[SequentialEventletHandler]: if not handler: handler = eventlet_handler.SequentialEventletHandler() handler.start() @@ -32,22 +33,17 @@ def start_stop_one(handler=None): class TestEventletHandler(unittest.TestCase): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletHandler, self).setUp() - - def test_started(self): + def test_started(self) -> None: with start_stop_one() as handler: assert handler.running is True assert len(handler._workers) != 0 assert handler.running is False - assert len(handler._workers) == 0 + assert len(handler._workers) == 0 # type: ignore[unreachable] - def test_spawn(self): + def test_spawn(self) -> None: captures = [] - def cb(): + def cb() -> None: captures.append(1) with start_stop_one() as handler: @@ -55,10 +51,10 @@ def cb(): assert len(captures) == 1 - def test_dispatch(self): + def test_dispatch(self) -> None: captures = [] - def cb(): + def cb() -> None: captures.append(1) with start_stop_one() as handler: @@ -66,10 +62,10 @@ def cb(): assert len(captures) == 1 - def test_async_link(self): - captures = [] + def test_async_link(self) -> None: + captures: list[SequentialEventletHandler] = [] - def cb(handler): + def cb(handler: SequentialEventletHandler) -> None: captures.append(handler) with start_stop_one() as handler: @@ -80,20 +76,20 @@ def cb(handler): assert len(captures) == 1 assert r.get() == 2 - def test_timeout_raising(self): + def test_timeout_raising(self) -> None: handler = eventlet_handler.SequentialEventletHandler() with pytest.raises(handler.timeout_exception): raise handler.timeout_exception("This is a timeout") - def test_async_ok(self): - captures = [] + def test_async_ok(self) -> None: + captures: list[Literal[1] | SequentialEventletHandler] = [] - def delayed(): + def delayed() -> Literal[1]: captures.append(1) return 1 - def after_delayed(handler): + def after_delayed(handler: SequentialEventletHandler) -> None: captures.append(handler) with start_stop_one() as handler: @@ -106,7 +102,7 @@ def after_delayed(handler): assert captures[0] == 1 assert r.get() == 1 - def test_get_with_no_block(self): + def test_get_with_no_block(self) -> None: handler = eventlet_handler.SequentialEventletHandler() with start_stop_one(handler): @@ -117,8 +113,8 @@ def test_get_with_no_block(self): r.set(1) assert r.get() == 1 - def test_async_exception(self): - def broken(): + def test_async_exception(self) -> None: + def broken() -> None: raise IOError("Failed") with start_stop_one() as handler: @@ -130,13 +126,11 @@ def broken(): with pytest.raises(IOError): r.get() - def test_huge_file_descriptor(self): + def test_huge_file_descriptor(self) -> None: try: import resource except ImportError: self.skipTest("resource module unavailable on this platform") - from eventlet.green import socket - from kazoo.handlers.utils import create_tcp_socket try: resource.setrlimit(resource.RLIMIT_NOFILE, (4096, 4096)) @@ -154,91 +148,3 @@ def test_huge_file_descriptor(self): h.stop() for sock in socks: sock.close() - - -class TestEventletClient(test_client.TestClient): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletClient, self).setUp() - - @staticmethod - def make_event(): - return threading.Event() - - @staticmethod - def make_condition(): - return threading.Condition() - - def _makeOne(self, *args): - return eventlet_handler.SequentialEventletHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - return KazooClient(self.hosts, **kwargs) - - -class TestEventletSemaphore(test_lock.TestSemaphore): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletSemaphore, self).setUp() - - @staticmethod - def make_condition(): - return threading.Condition() - - @staticmethod - def make_event(): - return threading.Event() - - @staticmethod - def make_thread(*args, **kwargs): - return threading.Thread(*args, **kwargs) - - def _makeOne(self, *args): - return eventlet_handler.SequentialEventletHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - c = KazooClient(self.hosts, **kwargs) - try: - self._clients.append(c) - except AttributeError: - self._client = [c] - return c - - -class TestEventletLock(test_lock.KazooLockTests): - def setUp(self): - if not EVENTLET_HANDLER_AVAILABLE: - pytest.skip("eventlet handler not available.") - super(TestEventletLock, self).setUp() - - @staticmethod - def make_condition(): - return threading.Condition() - - @staticmethod - def make_event(): - return threading.Event() - - @staticmethod - def make_thread(*args, **kwargs): - return threading.Thread(*args, **kwargs) - - @staticmethod - def make_wait(): - return test_util.Wait(getsleep=(lambda: eventlet.sleep)) - - def _makeOne(self, *args): - return eventlet_handler.SequentialEventletHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - c = KazooClient(self.hosts, **kwargs) - try: - self._clients.append(c) - except AttributeError: - self._client = [c] - return c diff --git a/kazoo/tests/test_exceptions.py b/kazoo/tests/test_exceptions.py index d2fb9c6f4..ff35a3fd5 100644 --- a/kazoo/tests/test_exceptions.py +++ b/kazoo/tests/test_exceptions.py @@ -1,30 +1,34 @@ +from __future__ import annotations + from unittest import TestCase +from types import ModuleType + import pytest class ExceptionsTestCase(TestCase): - def _get(self): + def _get(self) -> ModuleType: from kazoo import exceptions return exceptions - def test_backwards_alias(self): + def test_backwards_alias(self) -> None: module = self._get() assert hasattr(module, "NoNodeException") assert module.NoNodeException is module.NoNodeError - def test_exceptions_code(self): + def test_exceptions_code(self) -> None: module = self._get() exc_8 = module.EXCEPTIONS[-8] assert isinstance(exc_8(), module.BadArgumentsError) - def test_invalid_code(self): + def test_invalid_code(self) -> None: module = self._get() with pytest.raises(RuntimeError): module.EXCEPTIONS.__getitem__(666) - def test_exceptions_construction(self): + def test_exceptions_construction(self) -> None: module = self._get() exc = module.EXCEPTIONS[-101]() assert type(exc) is module.NoNodeError diff --git a/kazoo/tests/test_gevent_handler.py b/kazoo/tests/test_gevent_handler.py index 28dd46b08..c0c98d926 100644 --- a/kazoo/tests/test_gevent_handler.py +++ b/kazoo/tests/test_gevent_handler.py @@ -1,61 +1,60 @@ +from __future__ import annotations + import unittest import sys +from typing import Any, Type import pytest -from kazoo.client import KazooClient from kazoo.exceptions import NoNodeError -from kazoo.protocol.states import Callback +from kazoo.handlers.utils import create_tcp_socket +from kazoo.protocol.states import Callback, ZnodeStat from kazoo.testing import KazooTestCase -from kazoo.tests import test_client + +try: + import gevent # NOQA: + from gevent.event import Event + from gevent.queue import Empty + from gevent import socket + from kazoo.handlers.gevent import AsyncResult, SequentialGeventHandler +except ImportError: + pytestmark = pytest.mark.skip(reason="gevent not available") @pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") class TestGeventHandler(unittest.TestCase): - def setUp(self): - try: - import gevent # NOQA - except ImportError: - pytest.skip("gevent not available.") - - def _makeOne(self, *args): - from kazoo.handlers.gevent import SequentialGeventHandler - + def _makeOne(self, *args: Any) -> SequentialGeventHandler: return SequentialGeventHandler(*args) - def _getAsync(self, *args): - from kazoo.handlers.gevent import AsyncResult - + def _getAsync(self) -> Type[AsyncResult[Any]]: return AsyncResult - def _getEvent(self): - from gevent.event import Event - + def _getEvent(self) -> Type[Event]: return Event - def test_proper_threading(self): + def test_proper_threading(self) -> None: h = self._makeOne() h.start() assert isinstance(h.event_object(), self._getEvent()) - def test_matching_async(self): + def test_matching_async(self) -> None: h = self._makeOne() h.start() async_handler = self._getAsync() assert isinstance(h.async_result(), async_handler) - def test_exception_raising(self): + def test_exception_raising(self) -> None: h = self._makeOne() with pytest.raises(h.timeout_exception): raise h.timeout_exception("This is a timeout") - def test_exception_in_queue(self): + def test_exception_in_queue(self) -> None: h = self._makeOne() h.start() ev = self._getEvent()() - def func(): + def func() -> None: ev.set() raise ValueError("bang") @@ -63,14 +62,12 @@ def func(): h.dispatch_callback(call1) ev.wait() - def test_queue_empty_exception(self): - from gevent.queue import Empty - + def test_queue_empty_exception(self) -> None: h = self._makeOne() h.start() ev = self._getEvent()() - def func(): + def func() -> None: ev.set() raise Empty() @@ -81,30 +78,22 @@ def func(): @pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") class TestBasicGeventClient(KazooTestCase): - def setUp(self): - try: - import gevent # NOQA - except ImportError: - pytest.skip("gevent not available.") + def setUp(self) -> None: KazooTestCase.setUp(self) - def _makeOne(self, *args): - from kazoo.handlers.gevent import SequentialGeventHandler - + def _makeOne(self, *args: Any) -> SequentialGeventHandler: return SequentialGeventHandler(*args) - def _getEvent(self): - from gevent.event import Event - + def _getEvent(self) -> Type[Event]: return Event - def test_start(self): + def test_start(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() assert client.state == "CONNECTED" client.stop() - def test_start_stop_double(self): + def test_start_stop_double(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() assert client.state == "CONNECTED" @@ -112,7 +101,7 @@ def test_start_stop_double(self): client.handler.stop() client.stop() - def test_basic_commands(self): + def test_basic_commands(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() assert client.state == "CONNECTED" @@ -122,22 +111,23 @@ def test_basic_commands(self): assert client.exists("/anode") is None client.stop() - def test_failures(self): + def test_failures(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() with pytest.raises(NoNodeError): client.get("/none") client.stop() - def test_data_watcher(self): + def test_data_watcher(self) -> None: client = self._get_client(handler=self._makeOne()) client.start() client.ensure_path("/some/node") ev = self._getEvent()() @client.DataWatch("/some/node") - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: ev.set() + return None ev.wait() ev.clear() @@ -145,11 +135,11 @@ def changed(d, stat): ev.wait() client.stop() - def test_huge_file_descriptor(self): - import resource - from gevent import socket - from kazoo.handlers.utils import create_tcp_socket - + def test_huge_file_descriptor(self) -> None: + try: + import resource + except ImportError: + self.skipTest("resource module unavailable on this platform") try: resource.setrlimit(resource.RLIMIT_NOFILE, (4096, 4096)) except (ValueError, resource.error): @@ -166,22 +156,3 @@ def test_huge_file_descriptor(self): h.stop() for sock in socks: sock.close() - - -@pytest.mark.skipif(sys.platform == "win32", reason="does not run on windows") -class TestGeventClient(test_client.TestClient): - def setUp(self): - try: - import gevent # NOQA - except ImportError: - pytest.skip("gevent not available.") - KazooTestCase.setUp(self) - - def _makeOne(self, *args): - from kazoo.handlers.gevent import SequentialGeventHandler - - return SequentialGeventHandler(*args) - - def _get_client(self, **kwargs): - kwargs["handler"] = self._makeOne() - return KazooClient(self.hosts, **kwargs) diff --git a/kazoo/tests/test_hosts.py b/kazoo/tests/test_hosts.py index a2fae1ae0..80517d5d7 100644 --- a/kazoo/tests/test_hosts.py +++ b/kazoo/tests/test_hosts.py @@ -1,10 +1,12 @@ +from __future__ import annotations + from unittest import TestCase from kazoo.hosts import collect_hosts class HostsTestCase(TestCase): - def test_ipv4(self): + def test_ipv4(self) -> None: hosts, chroot = collect_hosts( "127.0.0.1:2181, 192.168.1.2:2181, \ 132.254.111.10:2181" @@ -26,7 +28,7 @@ def test_ipv4(self): ] assert chroot is None - def test_ipv6(self): + def test_ipv6(self) -> None: hosts, chroot = collect_hosts("[fe80::200:5aee:feaa:20a2]:2181") assert hosts == [("fe80::200:5aee:feaa:20a2", 2181)] assert chroot is None @@ -35,7 +37,7 @@ def test_ipv6(self): assert hosts == [("fe80::200:5aee:feaa:20a2", 2181)] assert chroot is None - def test_hosts_list(self): + def test_hosts_list(self) -> None: hosts, chroot = collect_hosts("zk01:2181, zk02:2181, zk03:2181") expected1 = [("zk01", 2181), ("zk02", 2181), ("zk03", 2181)] assert hosts == expected1 diff --git a/kazoo/tests/test_interrupt.py b/kazoo/tests/test_interrupt.py index ad4ae5d64..fd45ba1e0 100644 --- a/kazoo/tests/test_interrupt.py +++ b/kazoo/tests/test_interrupt.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os from sys import platform @@ -7,7 +9,7 @@ class KazooInterruptTests(KazooTestCase): - def test_interrupted_systemcall(self): + def test_interrupted_systemcall(self) -> None: """ Make sure interrupted system calls don't break the world, since we can't control what all signals our connection thread will get diff --git a/kazoo/tests/test_lease.py b/kazoo/tests/test_lease.py index 98a125606..2b8cadb39 100644 --- a/kazoo/tests/test_lease.py +++ b/kazoo/tests/test_lease.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import uuid @@ -8,18 +10,18 @@ class MockClock(object): - def __init__(self, epoch=0): + def __init__(self, epoch: float = 0): self.epoch = epoch - def forward(self, seconds): + def forward(self, seconds: float) -> None: self.epoch += seconds - def __call__(self): + def __call__(self) -> datetime.datetime: return datetime.datetime.utcfromtimestamp(self.epoch) class KazooLeaseTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooLeaseTests, self).setUp() self.client2 = self._get_client(timeout=0.8) self.client2.start() @@ -28,7 +30,7 @@ def setUp(self): self.path = "/" + uuid.uuid4().hex self.clock = MockClock(10) - def tearDown(self): + def tearDown(self) -> None: for cl in [self.client2, self.client3]: if cl.connected: cl.stop() @@ -38,7 +40,7 @@ def tearDown(self): class NonBlockingLeaseTests(KazooLeaseTests): - def test_renew(self): + def test_renew(self) -> None: # Use client convenience method here to test it at least once. Use # class directly in # other tests in order to get better IDE support. @@ -54,7 +56,7 @@ def test_renew(self): ) assert renewed_lease - def test_busy(self): + def test_busy(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -74,7 +76,7 @@ def test_busy(self): assert not foreigner_lease assert foreigner_lease.obtained is False - def test_overtake(self): + def test_overtake(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -93,7 +95,7 @@ def test_overtake(self): ) assert foreigner_lease - def test_renew_no_overtake(self): + def test_renew_no_overtake(self) -> None: lease = self.client.NonBlockingLease( self.path, datetime.timedelta(seconds=3), utcnow=self.clock ) @@ -116,7 +118,7 @@ def test_renew_no_overtake(self): ) assert not foreigner_lease - def test_overtaker_renews(self): + def test_overtaker_renews(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -145,7 +147,7 @@ def test_overtaker_renews(self): ) assert foreigner_renew - def test_overtake_refuse_first(self): + def test_overtake_refuse_first(self) -> None: lease = NonBlockingLease( self.client, self.path, @@ -173,7 +175,7 @@ def test_overtake_refuse_first(self): ) assert not first_again_lease - def test_old_version(self): + def test_old_version(self) -> None: # Skip to a future version NonBlockingLease._version += 1 lease = NonBlockingLease( @@ -199,7 +201,7 @@ def test_old_version(self): class MultiNonBlockingLeaseTest(KazooLeaseTests): - def test_1_renew(self): + def test_1_renew(self) -> None: ls = self.client.MultiNonBlockingLease( 1, self.path, datetime.timedelta(seconds=4), utcnow=self.clock ) @@ -214,7 +216,7 @@ def test_1_renew(self): ) assert ls2 - def test_1_reject(self): + def test_1_reject(self) -> None: ls = MultiNonBlockingLease( self.client, 1, @@ -234,7 +236,7 @@ def test_1_reject(self): ) assert not ls2 - def test_2_renew(self): + def test_2_renew(self) -> None: ls = MultiNonBlockingLease( self.client, 2, @@ -273,7 +275,7 @@ def test_2_renew(self): ) assert ls4 - def test_2_reject(self): + def test_2_reject(self) -> None: ls = MultiNonBlockingLease( self.client, 2, @@ -303,7 +305,7 @@ def test_2_reject(self): ) assert not ls3 - def test_2_handover(self): + def test_2_handover(self) -> None: ls = MultiNonBlockingLease( self.client, 2, diff --git a/kazoo/tests/test_lock.py b/kazoo/tests/test_lock.py index 7e5ecb775..18ac4b51d 100644 --- a/kazoo/tests/test_lock.py +++ b/kazoo/tests/test_lock.py @@ -1,15 +1,21 @@ +from __future__ import annotations + import collections import threading import unittest from unittest.mock import MagicMock import uuid +from typing import Any, Callable, Deque +from types import TracebackType +from threading import Thread + import pytest from kazoo.exceptions import CancelledError from kazoo.exceptions import LockTimeout from kazoo.exceptions import NoNodeError -from kazoo.recipe.lock import Lock +from kazoo.recipe.lock import Lock, Semaphore from kazoo.testing import KazooTestCase from kazoo.tests import util as test_util @@ -17,22 +23,28 @@ class SleepBarrier(object): """A crappy spinning barrier.""" - def __init__(self, wait_for, sleep_func): + def __init__(self, wait_for: int, sleep_func: Callable[..., None]): self._wait_for = wait_for - self._arrived = collections.deque() + self._arrived: Deque[Thread] = collections.deque() self._sleep_func = sleep_func - def __enter__(self): + def __enter__(self) -> SleepBarrier: self._arrived.append(threading.current_thread()) return self - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: try: self._arrived.remove(threading.current_thread()) except ValueError: pass + return None - def wait(self): + def wait(self) -> None: while len(self._arrived) < self._wait_for: self._sleep_func(0.001) @@ -40,43 +52,45 @@ def wait(self): class KazooLockTests(KazooTestCase): thread_count = 20 - def __init__(self, *args, **kw): + def __init__(self, *args: None, **kw: None): super(KazooLockTests, self).__init__(*args, **kw) - self.threads_made = [] + self.threads_made: list[threading.Thread] = [] - def tearDown(self): + def tearDown(self) -> None: super(KazooLockTests, self).tearDown() while self.threads_made: t = self.threads_made.pop() t.join() @staticmethod - def make_condition(): + def make_condition() -> threading.Condition: return threading.Condition() @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def make_thread(self, *args, **kwargs): + def make_thread(self, *args: Any, **kwargs: Any) -> threading.Thread: t = threading.Thread(*args, **kwargs) t.daemon = True self.threads_made.append(t) return t @staticmethod - def make_wait(): + def make_wait() -> test_util.Wait: return test_util.Wait() - def setUp(self): + def setUp(self) -> None: super(KazooLockTests, self).setUp() self.lockpath = "/" + uuid.uuid4().hex self.condition = self.make_condition() self.released = self.make_event() - self.active_thread = None - self.cancelled_threads = [] + self.active_thread: str | None = None + self.cancelled_threads: list[str] = [] - def _thread_lock_acquire_til_event(self, name, lock, event): + def _thread_lock_acquire_til_event( + self, name: str, lock: Lock, event: threading.Event + ) -> None: try: with lock: with self.condition: @@ -96,7 +110,7 @@ def _thread_lock_acquire_til_event(self, name, lock, event): self.cancelled_threads.append(name) self.condition.notify_all() - def test_lock_one(self): + def test_lock_one(self) -> None: lock_name = uuid.uuid4().hex lock = self.client.Lock(self.lockpath, lock_name) event = self.make_event() @@ -127,7 +141,7 @@ def test_lock_one(self): self.released.wait() thread.join() - def test_lock(self): + def test_lock(self) -> None: threads = [] names = ["contender" + str(i) for i in range(5)] @@ -181,7 +195,7 @@ def test_lock(self): for thread in threads: thread.join() - def test_lock_reconnect(self): + def test_lock_reconnect(self) -> None: event = self.make_event() other_lock = self.client.Lock(self.lockpath, "contender") thread = self.make_thread( @@ -211,7 +225,7 @@ def test_lock_reconnect(self): event.set() thread.join() - def test_lock_non_blocking(self): + def test_lock_non_blocking(self) -> None: lock_name = uuid.uuid4().hex lock = self.client.Lock(self.lockpath, lock_name) event = self.make_event() @@ -235,7 +249,7 @@ def test_lock_non_blocking(self): event.set() thread.join() - def test_lock_fail_first_call(self): + def test_lock_fail_first_call(self) -> None: event1 = self.make_event() lock1 = self.client.Lock(self.lockpath, "one") thread1 = self.make_thread( @@ -248,12 +262,15 @@ def test_lock_fail_first_call(self): with self.condition: if not self.active_thread: self.condition.wait(5) - assert self.active_thread == "one" + assert ( + self.active_thread + == "one" # type: ignore[comparison-overlap] + ) assert lock1.contenders() == ["one"] event1.set() thread1.join() - def test_lock_cancel(self): + def test_lock_cancel(self) -> None: event1 = self.make_event() lock1 = self.client.Lock(self.lockpath, "one") thread1 = self.make_thread( @@ -266,7 +283,10 @@ def test_lock_cancel(self): with self.condition: if not self.active_thread: self.condition.wait(5) - assert self.active_thread == "one" + assert ( + self.active_thread + == "one" # type: ignore[comparison-overlap] + ) client2 = self._get_client() client2.start() @@ -296,7 +316,7 @@ def test_lock_cancel(self): thread1.join() client2.stop() - def test_lock_no_double_calls(self): + def test_lock_no_double_calls(self) -> None: lock1 = self.client.Lock(self.lockpath, "one") lock1.acquire() assert lock1.is_acquired is True @@ -305,7 +325,7 @@ def test_lock_no_double_calls(self): lock1.release() assert lock1.is_acquired is False - def test_lock_same_thread_no_block(self): + def test_lock_same_thread_no_block(self) -> None: lock = self.client.Lock(self.lockpath, "one") gotten = lock.acquire(blocking=False) assert gotten is True @@ -313,11 +333,11 @@ def test_lock_same_thread_no_block(self): gotten = lock.acquire(blocking=False) assert gotten is False - def test_lock_many_threads_no_block(self): + def test_lock_many_threads_no_block(self) -> None: lock = self.client.Lock(self.lockpath, "one") - attempts = collections.deque() + attempts: Deque[int] = collections.deque() - def _acquire(): + def _acquire() -> None: attempts.append(int(lock.acquire(blocking=False))) threads = [] @@ -332,14 +352,14 @@ def _acquire(): assert sum(list(attempts)) == 1 - def test_lock_many_threads(self): + def test_lock_many_threads(self) -> None: sleep_func = self.client.handler.sleep_func lock = self.client.Lock(self.lockpath, "one") - acquires = collections.deque() - differences = collections.deque() + acquires: Deque[int] = collections.deque() + differences: Deque[int] = collections.deque() barrier = SleepBarrier(self.thread_count, sleep_func) - def _acquire(): + def _acquire() -> None: # Wait until all threads are ready to go... with barrier as b: b.wait() @@ -366,18 +386,19 @@ def _acquire(): assert len(acquires) == self.thread_count assert list(differences) == [1] * self.thread_count - def test_lock_reacquire(self): + def test_lock_reacquire(self) -> None: lock = self.client.Lock(self.lockpath, "one") lock.acquire() lock.release() lock.acquire() lock.release() - def test_lock_ephemeral(self): + def test_lock_ephemeral(self) -> None: client1 = self._get_client() client1.start() lock = client1.Lock(self.lockpath, "ephemeral") lock.acquire(ephemeral=False) + assert lock.node is not None znode = self.lockpath + "/" + lock.node client1.stop() try: @@ -385,7 +406,7 @@ def test_lock_ephemeral(self): except NoNodeError: self.fail("NoNodeError raised unexpectedly!") - def test_lock_timeout(self): + def test_lock_timeout(self) -> None: timeout = 3 e = self.make_event() started = self.make_event() @@ -394,7 +415,9 @@ def test_lock_timeout(self): # that the main thread is going to wait to acquire the lock. lock1 = self.client.Lock(self.lockpath, "one") - def _thread(lock, event, timeout): + def _thread( + lock: threading.Lock, event: threading.Event, timeout: float + ) -> None: with lock: started.set() event.wait(timeout) @@ -426,7 +449,7 @@ def _thread(lock, event, timeout): t.join() client2.stop() - def test_read_lock(self): + def test_read_lock(self) -> None: # Test that we can obtain a read lock lock = self.client.ReadLock(self.lockpath, "reader one") gotten = lock.acquire(blocking=False) @@ -451,7 +474,7 @@ def test_read_lock(self): gotten = lock3.acquire(blocking=False) assert gotten is False - def test_write_lock(self): + def test_write_lock(self) -> None: # Test that we can obtain a write lock lock = self.client.WriteLock(self.lockpath, "writer") gotten = lock.acquire(blocking=False) @@ -468,7 +491,7 @@ def test_write_lock(self): gotten = lock2.acquire(blocking=False) assert gotten is False - def test_rw_lock(self): + def test_rw_lock(self) -> None: reader_event = self.make_event() reader_lock = self.client.ReadLock(self.lockpath, "reader") reader_thread = self.make_thread( @@ -508,8 +531,8 @@ def test_rw_lock(self): # release the lock and contenders should claim it in order lock.release() - for contender, contender_bits in contender_bits.items(): - _, event = contender_bits + for contender, bits in contender_bits.items(): + _, event = bits with self.condition: while not self.active_thread: @@ -530,44 +553,44 @@ def test_rw_lock(self): class TestSemaphore(KazooTestCase): - def __init__(self, *args, **kw): + def __init__(self, *args: Any, **kw: Any): super(TestSemaphore, self).__init__(*args, **kw) - self.threads_made = [] + self.threads_made: list[threading.Thread] = [] - def tearDown(self): + def tearDown(self) -> None: super(TestSemaphore, self).tearDown() while self.threads_made: t = self.threads_made.pop() t.join() @staticmethod - def make_condition(): + def make_condition() -> threading.Condition: return threading.Condition() @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def make_thread(self, *args, **kwargs): + def make_thread(self, *args: Any, **kwargs: Any) -> threading.Thread: t = threading.Thread(*args, **kwargs) t.daemon = True self.threads_made.append(t) return t - def setUp(self): + def setUp(self) -> None: super(TestSemaphore, self).setUp() self.lockpath = "/" + uuid.uuid4().hex self.condition = self.make_condition() self.released = self.make_event() self.active_thread = None - self.cancelled_threads = [] + self.cancelled_threads: list[str] = [] - def test_basic(self): + def test_basic(self) -> None: sem1 = self.client.Semaphore(self.lockpath) sem1.acquire() sem1.release() - def test_lock_one(self): + def test_lock_one(self) -> None: sem1 = self.client.Semaphore(self.lockpath, max_leases=1) sem2 = self.client.Semaphore(self.lockpath, max_leases=1) started = self.make_event() @@ -575,7 +598,7 @@ def test_lock_one(self): sem1.acquire() - def sema_one(): + def sema_one() -> None: started.set() with sem2: event.set() @@ -591,7 +614,7 @@ def sema_one(): assert event.is_set() is True thread.join() - def test_non_blocking(self): + def test_non_blocking(self) -> None: sem1 = self.client.Semaphore( self.lockpath, identifier="sem1", max_leases=2 ) @@ -613,7 +636,7 @@ def test_non_blocking(self): sem1.release() sem3.release() - def test_non_blocking_release(self): + def test_non_blocking_release(self) -> None: sem1 = self.client.Semaphore( self.lockpath, identifier="sem1", max_leases=1 ) @@ -627,11 +650,11 @@ def test_non_blocking_release(self): sem1.release() sem2.release() - def test_holders(self): + def test_holders(self) -> None: started = self.make_event() event = self.make_event() - def sema_one(): + def sema_one() -> None: with self.client.Semaphore(self.lockpath, "fred", max_leases=1): started.set() event.wait() @@ -645,14 +668,14 @@ def sema_one(): event.set() thread.join() - def test_semaphore_cancel(self): + def test_semaphore_cancel(self) -> None: sem1 = self.client.Semaphore(self.lockpath, "fred", max_leases=1) sem2 = self.client.Semaphore(self.lockpath, "george", max_leases=1) sem1.acquire() started = self.make_event() event = self.make_event() - def sema_one(): + def sema_one() -> None: started.set() try: with sem2: @@ -670,7 +693,7 @@ def sema_one(): assert event.is_set() thread.join() - def test_multiple_acquire_and_release(self): + def test_multiple_acquire_and_release(self) -> None: sem1 = self.client.Semaphore(self.lockpath, "fred", max_leases=1) sem1.acquire() sem1.acquire() @@ -678,7 +701,7 @@ def test_multiple_acquire_and_release(self): assert sem1.release() assert not sem1.release() - def test_handle_session_loss(self): + def test_handle_session_loss(self) -> None: expire_semaphore = self.client.Semaphore( self.lockpath, "fred", max_leases=1 ) @@ -692,7 +715,7 @@ def test_handle_session_loss(self): event = self.make_event() event2 = self.make_event() - def sema_one(): + def sema_one() -> None: started.set() with expire_semaphore: event.set() @@ -707,7 +730,7 @@ def sema_one(): # Fired in a separate thread to make sure we can see the effect expired = self.make_event() - def expire(): + def expire() -> None: self.expire_session(self.make_event) expired.set() @@ -726,7 +749,7 @@ def expire(): for t in (thread1, thread2): t.join() - def test_inconsistent_max_leases(self): + def test_inconsistent_max_leases(self) -> None: sem1 = self.client.Semaphore(self.lockpath, max_leases=1) sem2 = self.client.Semaphore(self.lockpath, max_leases=2) @@ -734,7 +757,7 @@ def test_inconsistent_max_leases(self): with pytest.raises(ValueError): sem2.acquire() - def test_inconsistent_max_leases_other_data(self): + def test_inconsistent_max_leases_other_data(self) -> None: sem1 = self.client.Semaphore(self.lockpath, max_leases=1) sem2 = self.client.Semaphore(self.lockpath, max_leases=2) @@ -745,14 +768,14 @@ def test_inconsistent_max_leases_other_data(self): # sem2 thinks it's ok to have two lease holders assert sem2.acquire(blocking=False) - def test_reacquire(self): + def test_reacquire(self) -> None: lock = self.client.Semaphore(self.lockpath) lock.acquire() lock.release() lock.acquire() lock.release() - def test_acquire_after_cancelled(self): + def test_acquire_after_cancelled(self) -> None: lock = self.client.Semaphore(self.lockpath) assert lock.acquire() is True assert lock.release() is True @@ -760,7 +783,7 @@ def test_acquire_after_cancelled(self): assert lock.cancelled is True assert lock.acquire() is True - def test_timeout(self): + def test_timeout(self) -> None: timeout = 3 e = self.make_event() started = self.make_event() @@ -769,7 +792,9 @@ def test_timeout(self): # that the main thread is going to wait to acquire the lock. sem1 = self.client.Semaphore(self.lockpath, "one") - def _thread(sem, event, timeout): + def _thread( + sem: Semaphore, event: threading.Event, timeout: float + ) -> None: with sem: started.set() event.wait(timeout) @@ -800,7 +825,7 @@ def _thread(sem, event, timeout): class TestSequence(unittest.TestCase): - def test_get_predecessor(self): + def test_get_predecessor(self) -> None: """Validate selection of predecessors.""" goLock = "_c_8eb60557ba51e0da67eefc47467d3f34-lock-0000000031" pyLock = "514e5a831836450cb1a56c741e990fd8__lock__0000000032" @@ -810,7 +835,7 @@ def test_get_predecessor(self): lock = Lock(client, "test") assert lock._get_predecessor(pyLock) is None - def test_get_predecessor_go(self): + def test_get_predecessor_go(self) -> None: """Test selection of predecessor when instructed to consider go-zk locks. """ diff --git a/kazoo/tests/test_partitioner.py b/kazoo/tests/test_partitioner.py index b2b91c8bb..35076ee23 100644 --- a/kazoo/tests/test_partitioner.py +++ b/kazoo/tests/test_partitioner.py @@ -1,11 +1,15 @@ +from __future__ import annotations + import uuid import threading import time from unittest.mock import patch +from kazoo.client import KazooClient +from kazoo.interfaces import Lockable from kazoo.exceptions import LockTimeout from kazoo.testing import KazooTestCase -from kazoo.recipe.partitioner import PartitionState +from kazoo.recipe.partitioner import PartitionState, SetPartitioner class SlowLockMock: @@ -13,14 +17,19 @@ class SlowLockMock: default_delay_time = 3 - def __init__(self, client, lock, delay_time=None): + def __init__( + self, + client: KazooClient, + lock: Lockable, + delay_time: float | None = None, + ): self._client = client self._lock = lock self.delay_time = ( self.default_delay_time if delay_time is None else delay_time ) - def acquire(self, timeout=None): + def acquire(self, timeout: float | None = None) -> bool: sleep = self._client.handler.sleep_func sleep(self.delay_time) @@ -37,28 +46,32 @@ def acquire(self, timeout=None): raise LockTimeout("Mocked slow lock has timed out.") - def release(self): + def release(self) -> None: self._lock.release() +PartitionData = int +Partitioner = SetPartitioner[PartitionData] + + class KazooPartitionerTests(KazooTestCase): @staticmethod - def make_event(): + def make_event() -> threading.Event: return threading.Event() - def setUp(self): + def setUp(self) -> None: super(KazooPartitionerTests, self).setUp() self.path = "/" + uuid.uuid4().hex - self.__partitioners = [] + self.__partitioners: list[Partitioner] = [] - def test_party_of_one(self): + def test_party_of_one(self) -> None: self.__create_partitioner(size=3) self.__wait_for_acquire() self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0, 1, 2]) self.__finish() - def test_party_of_two(self): + def test_party_of_two(self) -> None: for i in range(2): self.__create_partitioner(size=2, identifier=str(i)) @@ -70,7 +83,7 @@ def test_party_of_two(self): assert self.__partitioners[1].release self.__partitioners[1].finish() - def test_party_expansion(self): + def test_party_expansion(self) -> None: for i in range(2): self.__create_partitioner(size=3, identifier=str(i)) @@ -88,7 +101,7 @@ def test_party_expansion(self): self.__assert_state( PartitionState.RELEASE, partitioners=self.__partitioners[:-1] ) - for partitioner in self.__partitioners[-1]: + for partitioner in self.__partitioners[:-1]: assert partitioner.state_change_event.is_set() self.__release(self.__partitioners[:-1]) @@ -97,7 +110,7 @@ def test_party_expansion(self): self.__finish() - def test_more_members_than_set_items(self): + def test_more_members_than_set_items(self) -> None: for i in range(2): self.__create_partitioner(size=1, identifier=str(i)) @@ -107,7 +120,7 @@ def test_more_members_than_set_items(self): self.__finish() - def test_party_session_failure(self): + def test_party_session_failure(self) -> None: partitioner = self.__create_partitioner(size=3) self.__wait_for_acquire() assert partitioner.state == PartitionState.ACQUIRED @@ -116,7 +129,7 @@ def test_party_session_failure(self): partitioner.release_set() assert partitioner.failed is True - def test_connection_loss(self): + def test_connection_loss(self) -> None: self.__create_partitioner(identifier="0", size=3) self.__create_partitioner(identifier="1", size=3) @@ -146,10 +159,10 @@ def test_connection_loss(self): self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0], [1], [2]) - def test_race_condition_new_partitioner_during_the_lock(self): - locks = {} + def test_race_condition_new_partitioner_during_the_lock(self) -> None: + locks: dict[str, Lockable] = {} - def get_lock(path): + def get_lock(path: str) -> SlowLockMock: lock = locks.setdefault(path, self.client.handler.lock_object()) return SlowLockMock(self.client, lock) @@ -175,10 +188,10 @@ def get_lock(path): self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0], [1]) - def test_race_condition_new_partitioner_steals_the_lock(self): - locks = {} + def test_race_condition_new_partitioner_steals_the_lock(self) -> None: + locks: dict[str, Lockable] = {} - def get_lock(path): + def get_lock(path: str) -> SlowLockMock: new_lock = self.client.handler.lock_object() lock = locks.setdefault(path, new_lock) @@ -214,7 +227,9 @@ def get_lock(path): self.__assert_state(PartitionState.ACQUIRED) self.__assert_partitions([0], [1]) - def __create_partitioner(self, size, identifier=None): + def __create_partitioner( + self, size: int, identifier: str | None = None + ) -> Partitioner: partitioner = self.client.SetPartitioner( self.path, set=range(size), @@ -224,34 +239,38 @@ def __create_partitioner(self, size, identifier=None): self.__partitioners.append(partitioner) return partitioner - def __wait_for_acquire(self): + def __wait_for_acquire(self) -> None: for partitioner in self.__partitioners: partitioner.wait_for_acquire(14) - def __assert_state(self, state, partitioners=None): + def __assert_state( + self, + state: PartitionState, + partitioners: list[Partitioner] | None = None, + ) -> None: if partitioners is None: partitioners = self.__partitioners for partitioner in partitioners: assert partitioner.state == state - def __assert_partitions(self, *partitions): + def __assert_partitions(self, *partitions: list[PartitionData]) -> None: assert len(partitions) == len(self.__partitioners) for partitioner, own_partitions in zip( self.__partitioners, partitions ): assert list(partitioner) == own_partitions - def __wait(self): + def __wait(self) -> None: time.sleep(0.1) - def __release(self, partitioners=None): + def __release(self, partitioners: list[Partitioner] | None = None) -> None: if partitioners is None: partitioners = self.__partitioners for partitioner in partitioners: partitioner.release_set() - def __finish(self): + def __finish(self) -> None: for partitioner in self.__partitioners: partitioner.finish() diff --git a/kazoo/tests/test_party.py b/kazoo/tests/test_party.py index 1b32523c4..f503eb333 100644 --- a/kazoo/tests/test_party.py +++ b/kazoo/tests/test_party.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import uuid from kazoo.testing import KazooTestCase class KazooPartyTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooPartyTests, self).setUp() self.path = "/" + uuid.uuid4().hex - def test_party(self): + def test_party(self) -> None: parties = [self.client.Party(self.path, "p%s" % i) for i in range(5)] one_party = parties[0] @@ -31,7 +33,7 @@ def test_party(self): assert set(party) == participants assert len(party) == len(participants) - def test_party_reuse_node(self): + def test_party_reuse_node(self) -> None: party = self.client.Party(self.path, "p1") self.client.ensure_path(self.path) self.client.create(party.create_path) @@ -39,24 +41,26 @@ def test_party_reuse_node(self): assert party.participating is True party.leave() assert party.participating is False - assert len(party) == 0 + # This appears to be an issue with mypy + assert len(party) == 0 # type: ignore[unreachable] - def test_party_vanishing_node(self): + def test_party_vanishing_node(self) -> None: party = self.client.Party(self.path, "p1") party.join() assert party.participating is True self.client.delete(party.create_path) party.leave() assert party.participating is False - assert len(party) == 0 + # This appears to be an issue with mypy + assert len(party) == 0 # type: ignore[unreachable] class KazooShallowPartyTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooShallowPartyTests, self).setUp() self.path = "/" + uuid.uuid4().hex - def test_party(self): + def test_party(self) -> None: parties = [ self.client.ShallowParty(self.path, "p%s" % i) for i in range(5) ] diff --git a/kazoo/tests/test_paths.py b/kazoo/tests/test_paths.py index 438c2ecad..a8064d486 100644 --- a/kazoo/tests/test_paths.py +++ b/kazoo/tests/test_paths.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from unittest import TestCase import pytest @@ -6,67 +8,67 @@ class NormPathTestCase(TestCase): - def test_normpath(self): + def test_normpath(self) -> None: assert paths.normpath("/a/b") == "/a/b" - def test_normpath_empty(self): + def test_normpath_empty(self) -> None: assert paths.normpath("") == "" - def test_normpath_unicode(self): + def test_normpath_unicode(self) -> None: assert paths.normpath("/\xe4/b") == "/\xe4/b" - def test_normpath_dots(self): + def test_normpath_dots(self) -> None: assert paths.normpath("/a./b../c") == "/a./b../c" - def test_normpath_slash(self): + def test_normpath_slash(self) -> None: assert paths.normpath("/") == "/" - def test_normpath_multiple_slashes(self): + def test_normpath_multiple_slashes(self) -> None: assert paths.normpath("//") == "/" assert paths.normpath("//a/b") == "/a/b" assert paths.normpath("/a//b//") == "/a/b" assert paths.normpath("//a////b///c/") == "/a/b/c" - def test_normpath_relative(self): + def test_normpath_relative(self) -> None: with pytest.raises(ValueError): paths.normpath("./a/b") with pytest.raises(ValueError): paths.normpath("/a/../b") - def test_normpath_trailing(self): + def test_normpath_trailing(self) -> None: assert paths.normpath("/", trailing=True) == "/" class JoinTestCase(TestCase): - def test_join(self): + def test_join(self) -> None: assert paths.join("/a") == "/a" assert paths.join("/a", "b/") == "/a/b/" assert paths.join("/a", "b", "c") == "/a/b/c" - def test_join_empty(self): + def test_join_empty(self) -> None: assert paths.join("") == "" assert paths.join("", "a", "b") == "a/b" assert paths.join("/a", "", "b/", "c") == "/a/b/c" - def test_join_absolute(self): + def test_join_absolute(self) -> None: assert paths.join("/a/b", "/c") == "/c" class IsAbsTestCase(TestCase): - def test_isabs(self): + def test_isabs(self) -> None: assert paths.isabs("/") is True assert paths.isabs("/a") is True assert paths.isabs("/a//b/c") is True assert paths.isabs("//a/b") is True - def test_isabs_false(self): + def test_isabs_false(self) -> None: assert paths.isabs("") is False assert paths.isabs("a/") is False assert paths.isabs("a/../") is False class BaseNameTestCase(TestCase): - def test_basename(self): + def test_basename(self) -> None: assert paths.basename("") == "" assert paths.basename("/") == "" assert paths.basename("//a") == "a" @@ -75,7 +77,7 @@ def test_basename(self): class PrefixRootTestCase(TestCase): - def test_prefix_root(self): + def test_prefix_root(self) -> None: assert paths._prefix_root("/a/", "b/c") == "/a/b/c" assert paths._prefix_root("/a/b", "c/d") == "/a/b/c/d" assert paths._prefix_root("/a", "/b/c") == "/a/b/c" @@ -83,7 +85,7 @@ def test_prefix_root(self): class NormRootTestCase(TestCase): - def test_norm_root(self): + def test_norm_root(self) -> None: assert paths._norm_root("") == "/" assert paths._norm_root("/") == "/" assert paths._norm_root("//a") == "/a" diff --git a/kazoo/tests/test_queue.py b/kazoo/tests/test_queue.py index b4b8455fa..0d86f57a7 100644 --- a/kazoo/tests/test_queue.py +++ b/kazoo/tests/test_queue.py @@ -1,36 +1,42 @@ +from __future__ import annotations + import uuid +from typing import Any + import pytest +from kazoo.interfaces import Event +from kazoo.recipe.queue import LockingQueue, Queue from kazoo.testing import KazooTestCase from kazoo.tests.util import CI_ZK_VERSION class KazooQueueTests(KazooTestCase): - def _makeOne(self): + def _makeOne(self) -> Queue: path = "/" + uuid.uuid4().hex return self.client.Queue(path) - def test_queue_validation(self): + def test_queue_validation(self) -> None: queue = self._makeOne() with pytest.raises(TypeError): - queue.put({}) + queue.put({}) # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", b"100") + queue.put(b"one", b"100") # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", 10.0) + queue.put(b"one", 10.0) # type: ignore[arg-type] with pytest.raises(ValueError): queue.put(b"one", -100) with pytest.raises(ValueError): queue.put(b"one", 100000) - def test_empty_queue(self): + def test_empty_queue(self) -> None: queue = self._makeOne() assert len(queue) == 0 assert queue.get() is None assert len(queue) == 0 - def test_queue(self): + def test_queue(self) -> None: queue = self._makeOne() queue.put(b"one") queue.put(b"two") @@ -42,7 +48,7 @@ def test_queue(self): assert queue.get() == b"three" assert len(queue) == 0 - def test_priority(self): + def test_priority(self) -> None: queue = self._makeOne() queue.put(b"four", priority=101) queue.put(b"one", priority=0) @@ -56,7 +62,7 @@ def test_priority(self): class KazooLockingQueueTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: KazooTestCase.setUp(self) skip = False if CI_ZK_VERSION and CI_ZK_VERSION < (3, 4): @@ -70,42 +76,42 @@ def setUp(self): if skip: pytest.skip("Must use Zookeeper 3.4 or above") - def _makeOne(self): + def _makeOne(self) -> LockingQueue: path = "/" + uuid.uuid4().hex return self.client.LockingQueue(path) - def test_queue_validation(self): + def test_queue_validation(self) -> None: queue = self._makeOne() with pytest.raises(TypeError): - queue.put({}) + queue.put({}) # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", b"100") + queue.put(b"one", b"100") # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put(b"one", 10.0) + queue.put(b"one", 10.0) # type: ignore[arg-type] with pytest.raises(ValueError): queue.put(b"one", -100) with pytest.raises(ValueError): queue.put(b"one", 100000) with pytest.raises(TypeError): - queue.put_all({}) + queue.put_all({}) # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put_all([{}]) + queue.put_all([{}]) # type: ignore[list-item] with pytest.raises(TypeError): - queue.put_all([b"one"], b"100") + queue.put_all([b"one"], b"100") # type: ignore[arg-type] with pytest.raises(TypeError): - queue.put_all([b"one"], 10.0) + queue.put_all([b"one"], 10.0) # type: ignore[arg-type] with pytest.raises(ValueError): queue.put_all([b"one"], -100) with pytest.raises(ValueError): queue.put_all([b"one"], 100000) - def test_empty_queue(self): + def test_empty_queue(self) -> None: queue = self._makeOne() assert len(queue) == 0 assert queue.get(0) is None assert len(queue) == 0 - def test_queue(self): + def test_queue(self) -> None: queue = self._makeOne() queue.put(b"one") queue.put_all([b"two", b"three"]) @@ -130,7 +136,7 @@ def test_queue(self): assert not queue.consume() assert len(queue) == 0 - def test_consume(self): + def test_consume(self) -> None: queue = self._makeOne() queue.put(b"one") @@ -139,7 +145,7 @@ def test_consume(self): assert queue.consume() assert not queue.consume() - def test_release(self): + def test_release(self) -> None: queue = self._makeOne() queue.put(b"one") @@ -152,7 +158,7 @@ def test_release(self): assert not queue.release() assert len(queue) == 0 - def test_holds_lock(self): + def test_holds_lock(self) -> None: queue = self._makeOne() assert not queue.holds_lock() @@ -162,7 +168,7 @@ def test_holds_lock(self): queue.consume() assert not queue.holds_lock() - def test_priority(self): + def test_priority(self) -> None: queue = self._makeOne() queue.put(b"four", priority=101) queue.put(b"one", priority=0) @@ -178,16 +184,16 @@ def test_priority(self): assert queue.get(1) == b"four" assert queue.consume() - def test_concurrent_execution(self): + def test_concurrent_execution(self) -> None: queue = self._makeOne() - value1 = [] - value2 = [] - value3 = [] + value1: list[bytes | None] = [] + value2: list[bytes | None] = [] + value3: list[bytes | None] = [] event1 = self.client.handler.event_object() event2 = self.client.handler.event_object() event3 = self.client.handler.event_object() - def get_concurrently(value, event): + def get_concurrently(value: list[Any], event: Event) -> None: q = self.client.LockingQueue(queue.path) value.append(q.get(0.1)) event.set() diff --git a/kazoo/tests/test_retry.py b/kazoo/tests/test_retry.py index acbe5bd9d..1c3f92820 100644 --- a/kazoo/tests/test_retry.py +++ b/kazoo/tests/test_retry.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import Any from unittest import mock import pytest @@ -6,16 +9,19 @@ from kazoo import retry as kr -def _make_retry(*args, **kwargs): +def _make_retry(*args: Any, **kwargs: Any) -> kr.KazooRetry: """Return a KazooRetry instance with a dummy sleep function.""" - def _sleep_func(_time): + def _sleep_func(_time: float) -> None: pass - return kr.KazooRetry(*args, sleep_func=_sleep_func, **kwargs) + # FIXME better way of doing this? Use TypedDict perhaps? + return kr.KazooRetry( + *args, sleep_func=_sleep_func, **kwargs # type: ignore[misc] + ) -def _make_try_func(times=1): +def _make_try_func(times: int = 1) -> mock.Mock: """Returns a function that raises ForceRetryError `times` time before returning None. """ @@ -25,7 +31,7 @@ def _make_try_func(times=1): return callmock -def test_call(): +def test_call() -> None: retry = _make_retry(delay=0, max_tries=2) func = _make_try_func() retry(func, "foo", bar="baz") @@ -35,7 +41,7 @@ def test_call(): ] -def test_reset(): +def test_reset() -> None: retry = _make_retry(delay=0, max_tries=2) func = _make_try_func() retry(func) @@ -46,7 +52,7 @@ def test_reset(): assert retry._attempts == 0 -def test_too_many_tries(): +def test_too_many_tries() -> None: retry = _make_retry(delay=0, max_tries=10) func = _make_try_func(times=999) with pytest.raises(kr.RetryFailedError): @@ -56,7 +62,7 @@ def test_too_many_tries(): ), "Called 10 times, failed _attempts 10" -def test_maximum_delay(): +def test_maximum_delay() -> None: retry = _make_retry(delay=10, max_tries=100, max_jitter=0) func = _make_try_func(times=2) retry(func) @@ -71,27 +77,27 @@ def test_maximum_delay(): assert isinstance(retry._cur_delay, float) -def test_copy(): +def test_copy() -> None: retry = _make_retry() rcopy = retry.copy() assert rcopy is not retry assert rcopy.sleep_func is retry.sleep_func -def test_connection_closed(): +def test_connection_closed() -> None: retry = _make_retry() - def testit(): + def testit() -> None: raise ke.ConnectionClosedError with pytest.raises(ke.ConnectionClosedError): retry(testit) -def test_session_expired(): +def test_session_expired() -> None: retry = _make_retry(max_tries=1) - def testit(): + def testit() -> None: raise ke.SessionExpiredError with pytest.raises(kr.RetryFailedError): diff --git a/kazoo/tests/test_sasl.py b/kazoo/tests/test_sasl.py index 6daa2a0b2..c70806fba 100644 --- a/kazoo/tests/test_sasl.py +++ b/kazoo/tests/test_sasl.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess import time @@ -13,7 +15,7 @@ class TestLegacySASLDigestAuthentication(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: try: import puresasl # NOQA except ImportError: @@ -29,11 +31,11 @@ def setUp(self): if not version or version < (3, 4): pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() os.environ.pop("ZOOKEEPER_JAAS_AUTH", None) - def test_connect_sasl_auth(self): + def test_connect_sasl_auth(self) -> None: from kazoo.security import make_acl username = "jaasuser" @@ -56,14 +58,14 @@ def test_connect_sasl_auth(self): client.stop() client.close() - def test_invalid_sasl_auth(self): + def test_invalid_sasl_auth(self) -> None: client = self._get_client(auth_data=[("sasl", "baduser:badpassword")]) with pytest.raises(AuthFailedError): client.start() class TestSASLDigestAuthentication(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: try: import puresasl # NOQA except ImportError: @@ -79,11 +81,11 @@ def setUp(self): if not version or version < (3, 4): pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() os.environ.pop("ZOOKEEPER_JAAS_AUTH", None) - def test_connect_sasl_auth(self): + def test_connect_sasl_auth(self) -> None: from kazoo.security import make_acl username = "jaasuser" @@ -110,7 +112,7 @@ def test_connect_sasl_auth(self): client.stop() client.close() - def test_invalid_sasl_auth(self): + def test_invalid_sasl_auth(self) -> None: client = self._get_client( sasl_options={ "mechanism": "DIGEST-MD5", @@ -123,13 +125,17 @@ def test_invalid_sasl_auth(self): class TestSASLGSSAPIAuthentication(KazooTestHarness): - def setUp(self): + def setUp(self) -> None: + # FIXME Under what circumstances is it ok for these to not be + # available? pip install is a thing. try: - import puresasl # NOQA + import puresasl except ImportError: pytest.skip("PureSASL not available.") try: - import kerberos # NOQA + # FIXME Hound objects to import not found as it thinks it's a + # syntax error. I don't know why it thinks that. + import kerberos # type: ignore except ImportError: pytest.skip("Kerberos support not available.") if not os.environ.get("KRB5_TEST_ENV"): @@ -145,11 +151,11 @@ def setUp(self): if not version or version < (3, 4): pytest.skip("Must use Zookeeper 3.4 or above") - def tearDown(self): + def tearDown(self) -> None: self.teardown_zookeeper() os.environ.pop("ZOOKEEPER_JAAS_AUTH", None) - def test_connect_gssapi_auth(self): + def test_connect_gssapi_auth(self) -> None: from kazoo.security import make_acl # Ensure we have a client ticket @@ -177,7 +183,7 @@ def test_connect_gssapi_auth(self): client.stop() client.close() - def test_invalid_gssapi_auth(self): + def test_invalid_gssapi_auth(self) -> None: # Request a post-datated ticket, so that it is currently invalid. subprocess.check_call( [ diff --git a/kazoo/tests/test_security.py b/kazoo/tests/test_security.py index bc45483af..ba21862b6 100644 --- a/kazoo/tests/test_security.py +++ b/kazoo/tests/test_security.py @@ -1,19 +1,23 @@ +from __future__ import annotations + import unittest -from kazoo.security import Permissions +from typing import Any + +from kazoo.security import ACL, Permissions class TestACL(unittest.TestCase): - def _makeOne(self, *args, **kwargs): + def _makeOne(self, *args: Any, **kwargs: Any) -> ACL: from kazoo.security import make_acl return make_acl(*args, **kwargs) - def test_read_acl(self): + def test_read_acl(self) -> None: acl = self._makeOne("digest", ":", read=True) - assert acl.perms & Permissions.READ == Permissions.READ + assert (acl.perms & Permissions.READ) == Permissions.READ - def test_all_perms(self): + def test_all_perms(self) -> None: acl = self._makeOne( "digest", ":", @@ -30,25 +34,28 @@ def test_all_perms(self): Permissions.DELETE, Permissions.ADMIN, ]: - assert acl.perms & perm == perm + assert (acl.perms & perm) == perm + + def test_perm_listing(self) -> None: + # FIXME ACL(n, Id) isn't an API, so why do we do this? - def test_perm_listing(self): - from kazoo.security import ACL + from kazoo.security import Id - f = ACL(15, "fred") + f = ACL(15, Id("fred", "bill")) assert "READ" in f.acl_list assert "WRITE" in f.acl_list assert "CREATE" in f.acl_list assert "DELETE" in f.acl_list - f = ACL(16, "fred") + f = ACL(16, Id("fred", "bill")) assert "ADMIN" in f.acl_list - f = ACL(31, "george") + f = ACL(31, Id("fred", "bill")) assert "ALL" in f.acl_list - def test_perm_repr(self): - from kazoo.security import ACL + def test_perm_repr(self) -> None: + # FIXME See above + from kazoo.security import Id - f = ACL(16, "fred") + f = ACL(16, Id("fred", "bill")) assert "ACL(perms=16, acl_list=['ADMIN']" in repr(f) diff --git a/kazoo/tests/test_selectors_select.py b/kazoo/tests/test_selectors_select.py index 99dd44ae7..8d367b179 100644 --- a/kazoo/tests/test_selectors_select.py +++ b/kazoo/tests/test_selectors_select.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ The official python select function test case copied from python source to test the selector_select function. @@ -8,7 +10,10 @@ import sys import unittest +from typing import cast + from kazoo.handlers.utils import selector_select +from kazoo.interfaces import HasFileNo select = selector_select @@ -21,10 +26,10 @@ class Nope: pass class Almost: - def fileno(self): + def fileno(self) -> str: return "fileno" - def test_error_conditions(self): + def test_error_conditions(self) -> None: self.assertRaises(TypeError, select, 1, 2, 3) self.assertRaises(TypeError, select, [self.Nope()], [], []) self.assertRaises(TypeError, select, [self.Almost()], [], []) @@ -36,41 +41,44 @@ def test_error_conditions(self): sys.platform.startswith("freebsd"), "skip because of a FreeBSD bug: kern/155606", ) - def test_errno(self): + def test_errno(self) -> None: with open(__file__, "rb") as fp: fd = fp.fileno() fp.close() self.assertRaises(ValueError, select, [fd], [], [], 0) - def test_returned_list_identity(self): + def test_returned_list_identity(self) -> None: # See issue #8329 r, w, x = select([], [], [], 1) self.assertIsNot(r, w) self.assertIsNot(r, x) self.assertIsNot(w, x) - def test_select(self): + def test_select(self) -> None: cmd = "for i in 0 1 2 3 4 5 6 7 8 9; do echo testing...; sleep 1; done" p = os.popen(cmd, "r") for tout in (0, 1, 2, 4, 8, 16) + (None,) * 10: - rfd, wfd, xfd = select([p], [], [], tout) + rfd, wfd, xfd = select([cast("HasFileNo", p)], [], [], tout) if (rfd, wfd, xfd) == ([], [], []): continue - if (rfd, wfd, xfd) == ([p], [], []): + if (rfd, wfd, xfd) == ([cast("HasFileNo", p)], [], []): line = p.readline() if not line: break continue - self.fail("Unexpected return values from select():", rfd, wfd, xfd) + self.fail( + "Unexpected return values from select(): %s %s %s" + % (rfd, wfd, xfd) + ) p.close() # Issue 16230: Crash on select resized list - def test_select_mutated(self): + def test_select_mutated(self) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - a = [] + a: list[HasFileNo] = [] class F: - def fileno(self): + def fileno(self) -> int: del a[-1] return s.fileno() diff --git a/kazoo/tests/test_threading_handler.py b/kazoo/tests/test_threading_handler.py index eac4ea9f8..f4ac5d434 100644 --- a/kazoo/tests/test_threading_handler.py +++ b/kazoo/tests/test_threading_handler.py @@ -1,41 +1,45 @@ +from __future__ import annotations + +import socket import threading import unittest + +from typing import Any, Type from unittest.mock import Mock import pytest +from kazoo.handlers.threading import AsyncResult, SequentialThreadingHandler +from kazoo.handlers.utils import create_tcp_socket -class TestThreadingHandler(unittest.TestCase): - def _makeOne(self, *args): - from kazoo.handlers.threading import SequentialThreadingHandler +class TestThreadingHandler(unittest.TestCase): + def _makeOne(self, *args: Any) -> SequentialThreadingHandler: return SequentialThreadingHandler(*args) - def _getAsync(self, *args): - from kazoo.handlers.threading import AsyncResult - + def _getAsync(self) -> Type[AsyncResult]: return AsyncResult - def test_proper_threading(self): + def test_proper_threading(self) -> None: h = self._makeOne() h.start() # In Python 3.3 _Event is gone, before Event is function event_class = getattr(threading, "_Event", threading.Event) assert isinstance(h.event_object(), event_class) - def test_matching_async(self): + def test_matching_async(self) -> None: h = self._makeOne() h.start() async_result = self._getAsync() assert isinstance(h.async_result(), async_result) - def test_exception_raising(self): + def test_exception_raising(self) -> None: h = self._makeOne() with pytest.raises(h.timeout_exception): raise h.timeout_exception("This is a timeout") - def test_double_start_stop(self): + def test_double_start_stop(self) -> None: h = self._makeOne() h.start() assert h._running is True @@ -44,14 +48,11 @@ def test_double_start_stop(self): h.stop() assert h._running is False - def test_huge_file_descriptor(self): + def test_huge_file_descriptor(self) -> None: try: import resource except ImportError: self.skipTest("resource module unavailable on this platform") - import socket - from kazoo.handlers.utils import create_tcp_socket - try: resource.setrlimit(resource.RLIMIT_NOFILE, (4096, 4096)) except (ValueError, resource.error): @@ -71,17 +72,13 @@ def test_huge_file_descriptor(self): class TestThreadingAsync(unittest.TestCase): - def _makeOne(self, *args): - from kazoo.handlers.threading import AsyncResult - + def _makeOne(self, *args: Any) -> AsyncResult: return AsyncResult(*args) - def _makeHandler(self): - from kazoo.handlers.threading import SequentialThreadingHandler - + def _makeHandler(self) -> SequentialThreadingHandler: return SequentialThreadingHandler() - def test_ready(self): + def test_ready(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -91,7 +88,7 @@ def test_ready(self): assert async_result.successful() is True assert async_result.exception is None - def test_callback_queued(self): + def test_callback_queued(self) -> None: mock_handler = Mock() mock_handler.completion_queue = Mock() async_result = self._makeOne(mock_handler) @@ -101,7 +98,7 @@ def test_callback_queued(self): assert mock_handler.completion_queue.put.called - def test_set_exception(self): + def test_set_exception(self) -> None: mock_handler = Mock() mock_handler.completion_queue = Mock() async_result = self._makeOne(mock_handler) @@ -111,7 +108,7 @@ def test_set_exception(self): assert isinstance(async_result.exception, ImportError) assert mock_handler.completion_queue.put.called - def test_get_wait_while_setting(self): + def test_get_wait_while_setting(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -119,7 +116,7 @@ def test_get_wait_while_setting(self): bv = threading.Event() cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: bv.set() val = async_result.get() lst.append(val) @@ -134,7 +131,7 @@ def wait_for_val(): assert lst == ["fred"] th.join() - def test_get_with_nowait(self): + def test_get_with_nowait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) timeout = self._makeHandler().timeout_exception @@ -145,7 +142,7 @@ def test_get_with_nowait(self): with pytest.raises(timeout): async_result.get_nowait() - def test_get_with_exception(self): + def test_get_with_exception(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -153,7 +150,7 @@ def test_get_with_exception(self): bv = threading.Event() cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: bv.set() try: val = async_result.get() @@ -167,20 +164,20 @@ def wait_for_val(): th.start() bv.wait() - async_result.set_exception(ImportError) + async_result.set_exception(ImportError()) cv.wait() assert lst == ["oops"] th.join() - def test_wait(self): + def test_wait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) - lst = [] + lst: list[bool | str] = [] bv = threading.Event() cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: bv.set() try: val = async_result.wait(10) @@ -199,7 +196,7 @@ def wait_for_val(): assert lst == [True] th.join() - def test_wait_race(self): + def test_wait_race(self) -> None: """Test that there is no race condition in `IAsyncResult.wait()`. Guards against the reappearance of: @@ -212,7 +209,7 @@ def test_wait_race(self): cv = threading.Event() - def wait_for_val(): + def wait_for_val() -> None: # NB: should not sleep async_result.wait(20) cv.set() @@ -227,7 +224,7 @@ def wait_for_val(): assert cv.is_set() is True th.join() - def test_set_before_wait(self): + def test_set_before_wait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) @@ -235,7 +232,7 @@ def test_set_before_wait(self): cv = threading.Event() async_result.set("fred") - def wait_for_val(): + def wait_for_val() -> None: val = async_result.get() lst.append(val) cv.set() @@ -246,15 +243,15 @@ def wait_for_val(): assert lst == ["fred"] th.join() - def test_set_exc_before_wait(self): + def test_set_exc_before_wait(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) lst = [] cv = threading.Event() - async_result.set_exception(ImportError) + async_result.set_exception(ImportError()) - def wait_for_val(): + def wait_for_val() -> None: try: val = async_result.get() except ImportError: @@ -269,17 +266,17 @@ def wait_for_val(): assert lst == ["ooops"] th.join() - def test_linkage(self): + def test_linkage(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) cv = threading.Event() lst = [] - def add_on(): + def add_on() -> None: lst.append(True) - def wait_for_val(): + def wait_for_val() -> None: async_result.get() cv.set() @@ -294,13 +291,13 @@ def wait_for_val(): assert async_result.value == b"fred" th.join() - def test_linkage_not_ready(self): + def test_linkage_not_ready(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) lst = [] - def add_on(): + def add_on() -> None: lst.append(True) async_result.set("fred") @@ -308,13 +305,13 @@ def add_on(): async_result.rawlink(add_on) assert mock_handler.completion_queue.put.called - def test_link_and_unlink(self): + def test_link_and_unlink(self) -> None: mock_handler = Mock() async_result = self._makeOne(mock_handler) lst = [] - def add_on(): + def add_on() -> None: lst.append(True) async_result.rawlink(add_on) @@ -323,14 +320,14 @@ def add_on(): async_result.set("fred") assert not mock_handler.completion_queue.put.called - def test_captured_exception(self): + def test_captured_exception(self) -> None: from kazoo.handlers.utils import capture_exceptions mock_handler = Mock() async_result = self._makeOne(mock_handler) @capture_exceptions(async_result) - def exceptional_function(): + def exceptional_function() -> float: return 1 / 0 exceptional_function() @@ -338,7 +335,7 @@ def exceptional_function(): with pytest.raises(ZeroDivisionError): async_result.get() - def test_no_capture_exceptions(self): + def test_no_capture_exceptions(self) -> None: from kazoo.handlers.utils import capture_exceptions mock_handler = Mock() @@ -346,20 +343,20 @@ def test_no_capture_exceptions(self): lst = [] - def add_on(): + def add_on() -> None: lst.append(True) async_result.rawlink(add_on) @capture_exceptions(async_result) - def regular_function(): + def regular_function() -> bool: return True regular_function() assert not mock_handler.completion_queue.put.called - def test_wraps(self): + def test_wraps(self) -> None: from kazoo.handlers.utils import wrap mock_handler = Mock() @@ -367,20 +364,20 @@ def test_wraps(self): lst = [] - def add_on(result): + def add_on(result: AsyncResult) -> None: lst.append(result.get()) async_result.rawlink(add_on) @wrap(async_result) - def regular_function(): + def regular_function() -> str: return "hello" assert regular_function() == "hello" assert mock_handler.completion_queue.put.called assert async_result.get() == "hello" - def test_multiple_callbacks(self): + def test_multiple_callbacks(self) -> None: mockback1 = Mock(name="mockback1") mockback2 = Mock(name="mockback2") handler = self._makeHandler() diff --git a/kazoo/tests/test_utils.py b/kazoo/tests/test_utils.py index 96d484a88..b686cabd6 100644 --- a/kazoo/tests/test_utils.py +++ b/kazoo/tests/test_utils.py @@ -1,10 +1,18 @@ +from __future__ import annotations + +import ssl +import socket +import time +from kazoo.handlers import utils +from kazoo.handlers.utils import create_tcp_connection + import unittest from unittest.mock import patch import pytest try: - from kazoo.handlers.eventlet import green_socket as socket + from eventlet.green import socket as green_socket EVENTLET_HANDLER_AVAILABLE = True except ImportError: @@ -12,10 +20,7 @@ class TestCreateTCPConnection(unittest.TestCase): - def test_timeout_arg(self): - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, socket, time - + def test_timeout_arg(self) -> None: with patch.object(socket, "create_connection") as create_connection: with patch.object(utils, "_set_default_tcpsock_options"): # Ensure a gap between calls to time.time() does not result in @@ -30,10 +35,7 @@ def test_timeout_arg(self): timeout = call_args[0][1] assert timeout >= 0, "socket timeout must be nonnegative" - def test_ssl_server_hostname(self): - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, socket, ssl - + def test_ssl_server_hostname(self) -> None: with patch.object(utils, "_set_default_tcpsock_options"): with patch.object(ssl.SSLContext, "wrap_socket") as wrap_socket: create_tcp_connection( @@ -48,10 +50,7 @@ def test_ssl_server_hostname(self): server_hostname = call_args[1]["server_hostname"] assert server_hostname == "fakehostname" - def test_ssl_server_check_hostname(self): - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, socket, ssl - + def test_ssl_server_check_hostname(self) -> None: with patch.object(utils, "_set_default_tcpsock_options"): with patch.object( ssl.SSLContext, "wrap_socket", autospec=True @@ -69,9 +68,7 @@ def test_ssl_server_check_hostname(self): ssl_context = call_args[0][0] assert ssl_context.check_hostname - def test_ssl_server_check_hostname_config_validation(self): - from kazoo.handlers.utils import create_tcp_connection, socket - + def test_ssl_server_check_hostname_config_validation(self) -> None: with pytest.raises(ValueError): create_tcp_connection( socket, @@ -83,51 +80,45 @@ def test_ssl_server_check_hostname_config_validation(self): check_hostname=True, ) - def test_timeout_arg_eventlet(self): + def test_timeout_arg_eventlet(self) -> None: if not EVENTLET_HANDLER_AVAILABLE: pytest.skip("eventlet handler not available.") - from kazoo.handlers import utils - from kazoo.handlers.utils import create_tcp_connection, time - - with patch.object(socket, "create_connection") as create_connection: + with patch.object( + green_socket, "create_connection" + ) as create_connection: with patch.object(utils, "_set_default_tcpsock_options"): # Ensure a gap between calls to time.time() does not result in # create_connection being called with a negative timeout # argument. with patch.object(time, "time", side_effect=range(10)): create_tcp_connection( - socket, ("127.0.0.1", 2181), timeout=1.5 + green_socket, ("127.0.0.1", 2181), timeout=1.5 ) for call_args in create_connection.call_args_list: timeout = call_args[0][1] assert timeout >= 0, "socket timeout must be nonnegative" - def test_slow_connect(self): + def test_slow_connect(self) -> None: # Currently, create_tcp_connection will raise a socket timeout if it # takes longer than the specified "timeout" to create a connection. # In the future, "timeout" might affect only the created socket and not # the time it takes to create it. - from kazoo.handlers.utils import create_tcp_connection, socket, time - # Simulate a second passing between calls to check the current time. with patch.object(time, "time", side_effect=range(10)): with pytest.raises(socket.error): create_tcp_connection(socket, ("127.0.0.1", 2181), timeout=0.5) - def test_negative_timeout(self): - from kazoo.handlers.utils import create_tcp_connection, socket - + def test_negative_timeout(self) -> None: with pytest.raises(socket.error): create_tcp_connection(socket, ("127.0.0.1", 2181), timeout=-1) - def test_zero_timeout(self): + def test_zero_timeout(self) -> None: # Rather than pass '0' through as a timeout to # socket.create_connection, create_tcp_connection should raise # socket.error. This is because the socket library treats '0' as an # indicator to create a non-blocking socket. - from kazoo.handlers.utils import create_tcp_connection, socket, time # Simulate no time passing between calls to check the current time. with patch.object(time, "time", return_value=time.time()): diff --git a/kazoo/tests/test_watchers.py b/kazoo/tests/test_watchers.py index dc36ef614..e0852a183 100644 --- a/kazoo/tests/test_watchers.py +++ b/kazoo/tests/test_watchers.py @@ -1,32 +1,39 @@ +from __future__ import annotations + import time import threading import uuid +from typing import Any, List, Literal + import pytest from kazoo.exceptions import KazooException -from kazoo.protocol.states import EventType +from kazoo.protocol.states import EventType, WatchedEvent, ZnodeStat +from kazoo.recipe.watchers import PatientChildrenWatch + from kazoo.testing import KazooTestCase class KazooDataWatcherTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooDataWatcherTests, self).setUp() self.path = "/" + uuid.uuid4().hex self.client.ensure_path(self.path) - def test_data_watcher(self): + def test_data_watcher(self) -> None: update = threading.Event() - data = [True] + data: list[bool | bytes | None] = [True] # Make it a non-existent path self.path += "f" @self.client.DataWatch(self.path) - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() data.append(d) update.set() + return None update.wait(10) assert data == [None] @@ -37,9 +44,9 @@ def changed(d, stat): assert data[0] == b"fred" update.clear() - def test_data_watcher_once(self): + def test_data_watcher_once(self) -> None: update = threading.Event() - data = [True] + data: list[bool | bytes | None] = [True] # Make it a non-existent path self.path += "f" @@ -47,10 +54,11 @@ def test_data_watcher_once(self): dwatcher = self.client.DataWatch(self.path) @dwatcher - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() data.append(d) update.set() + return None update.wait(10) assert data == [None] @@ -59,23 +67,27 @@ def changed(d, stat): with pytest.raises(KazooException): @dwatcher - def func(d, stat): + def func(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() + return None - def test_data_watcher_with_event(self): + def test_data_watcher_with_event(self) -> None: # Test that the data watcher gets passed the event, if it # accepts three arguments update = threading.Event() - data = [True] + data: list[Literal[True] | WatchedEvent | None] = [True] # Make it a non-existent path self.path += "f" @self.client.DataWatch(self.path) - def changed(d, stat, event): + def changed( + d: bytes | None, stat: ZnodeStat | None, event: WatchedEvent | None + ) -> bool | None: data.pop() data.append(event) update.set() + return None update.wait(10) assert data == [None] @@ -83,17 +95,18 @@ def changed(d, stat, event): self.client.create(self.path, b"fred") update.wait(10) + assert data[0] is not None and data[0] is not True assert data[0].type == EventType.CREATED update.clear() - def test_func_style_data_watch(self): + def test_func_style_data_watch(self) -> None: update = threading.Event() - data = [True] + data: list[bytes | None | Literal[True]] = [True] # Make it a non-existent path path = self.path + "f" - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> None: data.pop() data.append(d) update.set() @@ -109,12 +122,12 @@ def changed(d, stat): assert data[0] == b"fred" update.clear() - def test_datawatch_across_session_expire(self): + def test_datawatch_across_session_expire(self) -> None: update = threading.Event() - data = [True] + data: list[bytes | None | Literal[True]] = [True] @self.client.DataWatch(self.path) - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> None: data.pop() data.append(d) update.set() @@ -128,21 +141,22 @@ def changed(d, stat): update.wait(25) assert data[0] == b"fred" - def test_func_stops(self): + def test_func_stops(self) -> None: update = threading.Event() - data = [True] + data: list[bytes | None | Literal[True]] = [True] self.path += "f" - fail_through = [] + fail_through: list[bool] = [] @self.client.DataWatch(self.path) - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> bool | None: data.pop() data.append(d) update.set() if fail_through: return False + return None update.wait(10) assert data == [None] @@ -161,21 +175,21 @@ def changed(d, stat): d, stat = self.client.get(self.path) assert d == b"asdfasdf" - def test_no_such_node(self): + def test_no_such_node(self) -> None: args = [] @self.client.DataWatch("/some/path") - def changed(d, stat): + def changed(d: bytes | None, stat: ZnodeStat | None) -> None: args.extend([d, stat]) assert args == [None, None] - def test_no_such_node_for_children_watch(self): + def test_no_such_node_for_children_watch(self) -> None: args = [] path = self.path + "/test_no_such_node_for_children_watch" update = threading.Event() - def changed(children): + def changed(children: list[str] | None) -> None: args.append(children) update.set() @@ -218,9 +232,9 @@ def changed(children): assert update.is_set() is False assert children_watch._stopped is True - def test_watcher_evaluating_to_false(self): - class WeirdWatcher(list): - def __call__(self, *args): + def test_watcher_evaluating_to_false(self) -> None: + class WeirdWatcher(List[Any]): + def __call__(self, *args: Any) -> None: self.called = True watcher = WeirdWatcher() @@ -228,14 +242,14 @@ def __call__(self, *args): self.client.set(self.path, b"mwahaha") assert watcher.called is True - def test_watcher_repeat_delete(self): + def test_watcher_repeat_delete(self) -> None: a = [] ev = threading.Event() self.client.delete(self.path) @self.client.DataWatch(self.path) - def changed(val, stat): + def changed(val: bytes | None, stat: ZnodeStat | None) -> None: a.append(val) ev.set() @@ -258,14 +272,14 @@ def changed(val, stat): ev.clear() assert a == [None, b"blah", None, b"blah"] - def test_watcher_with_closing(self): + def test_watcher_with_closing(self) -> None: a = [] ev = threading.Event() self.client.delete(self.path) @self.client.DataWatch(self.path) - def changed(val, stat): + def changed(val: bytes | None, stat: ZnodeStat | None) -> None: a.append(val) ev.set() @@ -280,17 +294,18 @@ def changed(val, stat): class KazooChildrenWatcherTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooChildrenWatcherTests, self).setUp() self.path = "/" + uuid.uuid4().hex self.client.ensure_path(self.path) - def test_child_watcher(self): + def test_child_watcher(self) -> None: update = threading.Event() all_children = ["fred"] @self.client.ChildrenWatch(self.path) - def changed(children): + def changed(children: list[str] | None) -> None: + assert children is not None while all_children: all_children.pop() all_children.extend(children) @@ -309,16 +324,17 @@ def changed(children): update.wait(10) assert sorted(all_children) == ["george", "smith"] - def test_child_watcher_once(self): + def test_child_watcher_once(self) -> None: update = threading.Event() all_children = ["fred"] cwatch = self.client.ChildrenWatch(self.path) @cwatch - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -329,18 +345,21 @@ def changed(children): with pytest.raises(KazooException): @cwatch - def changed_again(children): + def changed_again(children: list[str] | None) -> None: update.set() - def test_child_watcher_with_event(self): + def test_child_watcher_with_event(self) -> None: update = threading.Event() - events = [True] + events: list[WatchedEvent | None | Literal[True]] = [True] @self.client.ChildrenWatch(self.path, send_event=True) - def changed(children, event): + def changed( + children: list[str] | None, event: WatchedEvent | None + ) -> bool | None: events.pop() events.append(event) update.set() + return None update.wait(10) assert events == [None] @@ -348,16 +367,19 @@ def changed(children, event): self.client.create(self.path + "/" + "smith") update.wait(10) + assert events[0] is not None + assert events[0] is not True assert events[0].type == EventType.CHILD update.clear() - def test_func_style_child_watcher(self): + def test_func_style_child_watcher(self) -> None: update = threading.Event() all_children = ["fred"] - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -376,20 +398,22 @@ def changed(children): update.wait(10) assert sorted(all_children) == ["george", "smith"] - def test_func_stops(self): + def test_func_stops(self) -> None: update = threading.Event() all_children = ["fred"] - fail_through = [] + fail_through: list[bool] = [] @self.client.ChildrenWatch(self.path) - def changed(children): + def changed(children: list[str] | None) -> bool | None: + assert children is not None while all_children: all_children.pop() all_children.extend(children) update.set() if fail_through: return False + return None # ? True? update.wait(10) assert all_children == [] @@ -405,19 +429,21 @@ def changed(children): update.wait(0.5) assert all_children == ["smith"] - def test_child_watcher_remove_session_watcher(self): + def test_child_watcher_remove_session_watcher(self) -> None: update = threading.Event() all_children = ["fred"] - fail_through = [] + fail_through: list[bool] = [] - def changed(children): + def changed(children: list[str] | None) -> bool | None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() if fail_through: return False + return None # ? children_watch = self.client.ChildrenWatch(self.path, changed) session_watcher = children_watch._session_watcher @@ -439,14 +465,15 @@ def changed(children): assert session_watcher not in self.client.state_listeners assert all_children == ["smith"] - def test_child_watch_session_loss(self): + def test_child_watch_session_loss(self) -> None: update = threading.Event() all_children = ["fred"] @self.client.ChildrenWatch(self.path) - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -464,14 +491,15 @@ def changed(children): update.wait(20) assert sorted(all_children) == ["george", "smith"] - def test_child_stop_on_session_loss(self): + def test_child_stop_on_session_loss(self) -> None: update = threading.Event() all_children = ["fred"] @self.client.ChildrenWatch(self.path, allow_session_lost=False) - def changed(children): + def changed(children: list[str] | None) -> None: while all_children: all_children.pop() + assert children is not None all_children.extend(children) update.set() @@ -495,16 +523,14 @@ def changed(children): class KazooPatientChildrenWatcherTests(KazooTestCase): - def setUp(self): + def setUp(self) -> None: super(KazooPatientChildrenWatcherTests, self).setUp() self.path = "/" + uuid.uuid4().hex - def _makeOne(self, *args, **kwargs): - from kazoo.recipe.watchers import PatientChildrenWatch - + def _makeOne(self, *args: Any, **kwargs: Any) -> PatientChildrenWatch: return PatientChildrenWatch(*args, **kwargs) - def test_watch(self): + def test_watch(self) -> None: self.client.ensure_path(self.path) watcher = self._makeOne(self.client, self.path, 0.1) result = watcher.start() @@ -516,7 +542,7 @@ def test_watch(self): asy.get(timeout=1) assert asy.ready() is True - def test_exception(self): + def test_exception(self) -> None: from kazoo.exceptions import NoNodeError watcher = self._makeOne(self.client, self.path, 0.1) @@ -525,7 +551,7 @@ def test_exception(self): with pytest.raises(NoNodeError): result.get() - def test_watch_iterations(self): + def test_watch_iterations(self) -> None: self.client.ensure_path(self.path) watcher = self._makeOne(self.client, self.path, 0.5) result = watcher.start() diff --git a/kazoo/tests/util.py b/kazoo/tests/util.py index f72987180..c62112381 100644 --- a/kazoo/tests/util.py +++ b/kazoo/tests/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + ############################################################################## # # Copyright Zope Foundation and Contributors. @@ -16,46 +18,50 @@ import os import time +from typing import Any, Callable, Type + CI = os.environ.get("CI", False) -CI_ZK_VERSION = CI and os.environ.get("ZOOKEEPER_VERSION", None) -if CI_ZK_VERSION: - if "-" in CI_ZK_VERSION: - # Ignore pre-release markers like -alpha - CI_ZK_VERSION = CI_ZK_VERSION.split("-")[0] - CI_ZK_VERSION = tuple([int(n) for n in CI_ZK_VERSION.split(".")]) +CI_ZK_VERSION: tuple[int, ...] = tuple() +if CI: + has_version = os.environ.get("ZOOKEEPER_VERSION", "") + if has_version: + if "-" in has_version: + # Ignore pre-release markers like -alpha + has_version = has_version.split("-")[0] + CI_ZK_VERSION = tuple([int(n) for n in has_version.split(".")]) class Handler(logging.Handler): - def __init__(self, *names, **kw): + def __init__(self, *names: Any, **kw: Any): logging.Handler.__init__(self) self.names = names - self.records = [] + self.records: list[Any] = [] self.setLoggerLevel(**kw) - def setLoggerLevel(self, level=1): + def setLoggerLevel(self, level: int = 1) -> None: self.level = level - self.oldlevels = {} + self.oldlevels: dict[str, int] = {} - def emit(self, record): + def emit(self, record: Any) -> None: self.records.append(record) - def clear(self): + def clear(self) -> None: del self.records[:] - def install(self): + def install(self) -> None: for name in self.names: logger = logging.getLogger(name) self.oldlevels[name] = logger.level logger.setLevel(self.level) logger.addHandler(self) - def uninstall(self): + def uninstall(self) -> None: for name in self.names: logger = logging.getLogger(name) logger.setLevel(self.oldlevels[name]) logger.removeHandler(self) - def __str__(self): + def __str__(self) -> str: return "\n".join( [ ( @@ -78,7 +84,7 @@ def __str__(self): class InstalledHandler(Handler): - def __init__(self, *names, **kw): + def __init__(self, *names: Any, **kw: Any): Handler.__init__(self, *names, **kw) self.install() @@ -92,11 +98,11 @@ class TimeOutWaitingFor(Exception): def __init__( self, - timeout=None, - wait=None, - exception=None, - getnow=(lambda: time.monotonic), - getsleep=(lambda: time.sleep), + timeout: int | None = None, + wait: float | None = None, + exception: Type[Exception] = TimeOutWaitingFor, + getnow: Callable[[], Callable[[], float]] = (lambda: time.monotonic), + getsleep: Callable[[], Callable[[float], None]] = (lambda: time.sleep), ): if timeout is not None: self.timeout = timeout @@ -104,15 +110,21 @@ def __init__( if wait is not None: self.wait = wait - if exception is not None: - self.TimeOutWaitingFor = exception + self.exception = exception self.getnow = getnow self.getsleep = getsleep - def __call__(self, func=None, timeout=None, wait=None, message=None): - if func is None: - return lambda func: self(func, timeout, wait, message) + def __call__( + self, + func: Callable[[], Any], # | None = None, + timeout: float | None = None, + wait: float | None = None, + message: str | None = None, + ) -> None: + # if func is None: + # # Seriously WTF? + # return lambda func: self(func, timeout, wait, message) if func(): return @@ -131,9 +143,8 @@ def __call__(self, func=None, timeout=None, wait=None, message=None): if func(): return if now() > deadline: - raise self.TimeOutWaitingFor( - message or func.__doc__ or func.__name__ - ) + # raise self.TimeOutWaitingFor( + raise self.exception(message or func.__doc__ or func.__name__) wait = Wait() diff --git a/pyproject.toml b/pyproject.toml index db3890c58..869ed27c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,18 +6,9 @@ requires = [ [tool.black] line-length = 79 -target-version = ['py37', 'py38', 'py39', 'py310'] +# We need a later version of black for 312-314 +target-version = ['py38', 'py39', 'py310', 'py311'] include = '\.pyi?$' -# 'extend-exclude' excludes files or directories in addition to the defaults -# A regex preceded with ^/ will apply only to files and directories -# in the root of the project. -# ( -# ^/foo.py # exclude a file named foo.py in the root of the project. -# | .*_pb2.py # exclude autogenerated Protocol Buffer files anywhere in the -# ) -# # project. -extend-exclude = ''' -''' [tool.pytest.ini_options] addopts = "-ra -v --color=yes" @@ -42,8 +33,12 @@ ignore_missing_imports = false # Disallow dynamic typing disallow_any_unimported = true disallow_any_expr = false -disallow_any_decorated = true -disallow_any_explicit = true +# FIXME disallow_any_decorated is disabled because it produces an error for almost +# every decorated function (at least in python3.8) +disallow_any_decorated = false +# FIXME disallow_any_explicit is disabled because it produces a lot of errors +# which are currently hard to fix, and we want to avoid code changes as much as possible. +disallow_any_explicit = false disallow_any_generics = true disallow_subclassing_any = true @@ -68,6 +63,8 @@ warn_unreachable = true # Miscellaneous strictness flags allow_untyped_globals = false allow_redefinition = false +# FIXME Disabled for now as we need to fix the X509 import. +#enable_error_code = ["deprecated"] local_partial_types = true implicit_reexport = false strict_concatenate = true @@ -81,71 +78,25 @@ hide_error_codes = false pretty = true color_output = true error_summary = true -show_absolute_path = true +show_absolute_path = false # Miscellaneous warn_unused_configs = true verbosity = 0 -# FIXME: As type annotations are introduced, please remove the appropriate -# ignore_errors flag below. New modules should NOT be added here! +# This is a temporary workaround for the fact that mypy can +# produce different errors in 3.8 and 3.14, and I want to avoid code changes +# as much as possible. +disable_error_code = [ 'unused-ignore' ] [[tool.mypy.overrides]] -module = [ - 'kazoo.client', - 'kazoo.exceptions', - 'kazoo.handlers.eventlet', - 'kazoo.handlers.gevent', - 'kazoo.handlers.threading', - 'kazoo.handlers.utils', - 'kazoo.hosts', - 'kazoo.interfaces', - 'kazoo.loggingsupport', - 'kazoo.protocol.connection', - 'kazoo.protocol.paths', - 'kazoo.protocol.serialization', - 'kazoo.protocol.states', - 'kazoo.recipe.barrier', - 'kazoo.recipe.cache', - 'kazoo.recipe.counter', - 'kazoo.recipe.election', - 'kazoo.recipe.lease', - 'kazoo.recipe.lock', - 'kazoo.recipe.partitioner', - 'kazoo.recipe.party', - 'kazoo.recipe.queue', - 'kazoo.recipe.watchers', - 'kazoo.retry', - 'kazoo.security', - 'kazoo.testing.common', - 'kazoo.testing.harness', - 'kazoo.tests.conftest', - 'kazoo.tests.test_barrier', - 'kazoo.tests.test_build', - 'kazoo.tests.test_cache', - 'kazoo.tests.test_client', - 'kazoo.tests.test_connection', - 'kazoo.tests.test_counter', - 'kazoo.tests.test_election', - 'kazoo.tests.test_eventlet_handler', - 'kazoo.tests.test_exceptions', - 'kazoo.tests.test_gevent_handler', - 'kazoo.tests.test_hosts', - 'kazoo.tests.test_interrupt', - 'kazoo.tests.test_lease', - 'kazoo.tests.test_lock', - 'kazoo.tests.test_partitioner', - 'kazoo.tests.test_party', - 'kazoo.tests.test_paths', - 'kazoo.tests.test_queue', - 'kazoo.tests.test_retry', - 'kazoo.tests.test_sasl', - 'kazoo.tests.test_security', - 'kazoo.tests.test_selectors_select', - 'kazoo.tests.test_threading_handler', - 'kazoo.tests.test_utils', - 'kazoo.tests.test_watchers', - 'kazoo.tests.util', - 'kazoo.version' -] -ignore_errors = true + module = ["eventlet.*"] + follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["gevent.thread"] + follow_untyped_imports = true + +[[tool.mypy.overrides]] + module = ["puresasl.*"] + follow_untyped_imports = true diff --git a/setup.cfg b/setup.cfg index e3a7b8805..5bb8c0fb9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,19 +43,26 @@ zip_safe = false include_package_data = true packages = find: +[options.package_data] +kazoo = py.typed + [aliases] release = sdist bdist_wheel [egg_info] tag_build = dev -[bdist_wheel] -universal = true - [options.extras_require] dev = + black flake8 +# I have to pin this HERE because I can't persuade vscode to +# look in constraints.txt +other = + pyOpenSSL<26.2.0 + typing-extensions + test = objgraph pytest @@ -64,7 +71,7 @@ test = gevent>=1.2 ; implementation_name!='pypy' eventlet>=0.17.1 ; implementation_name!='pypy' pyjks - pyOpenSSL + %(other)s eventlet = eventlet>=0.17.1 @@ -80,8 +87,11 @@ docs = sphinx-autodoc-typehints>=1 typing = - mypy>=0.991 - pyOpenSSL + mypy==1.14.1 + types-gevent + types-mock + types-objgraph + types-pyjks alldeps = %(dev)s @@ -89,4 +99,5 @@ alldeps = %(gevent)s %(sasl)s %(docs)s + %(other)s %(typing)s diff --git a/tox.ini b/tox.ini index 567acd29c..054fd9d04 100644 --- a/tox.ini +++ b/tox.ini @@ -62,6 +62,8 @@ basepython = python3 extras = alldeps deps = mypy - mypy: types-mock + types-mock + pyopenssl + pytest usedevelop = True commands = mypy --config-file {toxinidir}/pyproject.toml {toxinidir}/kazoo