From 55be4ad1e20ea74127f523141ca9ad3dafba5f4a Mon Sep 17 00:00:00 2001
From: Vlada Dusek
Date: Fri, 29 May 2026 18:06:56 +0200
Subject: [PATCH] test: unify polling helpers with apify-sdk-python

Add generic call_with_exp_backoff and poll_until_condition test helpers,
matching the signatures shared with apify-sdk-python and crawlee-python.
Both accept sync or async callables (via a small _maybe_await helper,
typed with overloads), which fits this repo's sync/async client duality.

Rebuild collect_iterate_until_present on top of poll_until_condition: the
iterate/drain logic stays, but the polling loop is delegated to the shared
helper. The contract is unchanged - the last collected list is returned so
callers keep asserting with their own failure messages. Note that
max_attempts now sizes a time budget (max_attempts * interval seconds)
rather than an exact round count; the docstring is updated to say so.
---
 tests/integration/_utils.py | 136 ++++++++++++++++++++++++++++++++----
 1 file changed, 122 insertions(+), 14 deletions(-)

diff --git a/tests/integration/_utils.py b/tests/integration/_utils.py
index 4412808a..6351685b 100644
--- a/tests/integration/_utils.py
+++ b/tests/integration/_utils.py
@@ -1,17 +1,21 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
+import logging
 import secrets
 import string
 import time
 from collections.abc import AsyncIterator, Iterator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
+from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast, overload
 
 import pytest
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Coroutine
+    from collections.abc import Awaitable, Callable, Coroutine
+
+logger = logging.getLogger(__name__)
 
 # Environment variable names for test configuration
 TOKEN_ENV_VAR = 'APIFY_TEST_USER_API_TOKEN'
@@ -119,6 +123,107 @@ async def maybe_sleep(seconds: float, *, is_async: bool) -> None:
         time.sleep(seconds)  # noqa: ASYNC251
 
 
+async def _maybe_await(value: Awaitable[T] | T) -> T:
+    """Await `value` if it is awaitable, otherwise return it unchanged.
+
+    Lets `call_with_exp_backoff` and `poll_until_condition` accept both sync and async callables.
+    """
+    if inspect.isawaitable(value):
+        return await cast('Awaitable[T]', value)
+    return cast('T', value)
+
+
+@overload
+async def call_with_exp_backoff(
+    fn: Callable[[], Awaitable[T]],
+    condition: Callable[[T], bool] = ...,
+    *,
+    max_retries: int = ...,
+    base_delay: float = ...,
+) -> T: ...
+@overload
+async def call_with_exp_backoff(
+    fn: Callable[[], T],
+    condition: Callable[[T], bool] = ...,
+    *,
+    max_retries: int = ...,
+    base_delay: float = ...,
+) -> T: ...
+async def call_with_exp_backoff(
+    fn: Callable[[], Awaitable[T] | T],
+    condition: Callable[[T], bool] = bool,
+    *,
+    max_retries: int = 5,
+    base_delay: float = 1.0,
+) -> T:
+    """Call `fn`, retrying with exponential backoff until `condition(result)` is True.
+
+    Calls `fn` and checks whether `condition` holds for its result. If it does not, `fn` is retried up to
+    `max_retries` times, sleeping `base_delay * 2 ** attempt` seconds before each retry. The last result is
+    returned regardless of whether the condition was ever satisfied, so the caller can run its own assertion.
+
+    This is useful for eventually-consistent APIs where a freshly created resource may take a moment to become
+    visible. The default condition checks for a truthy result. Pass `max_retries=0` to call `fn` exactly once.
+
+    Unlike `poll_until_condition`, the delay between attempts grows exponentially rather than staying constant.
+    """
+    result = await _maybe_await(fn())
+    for attempt in range(max_retries):
+        if condition(result):
+            return result
+        delay = base_delay * 2**attempt
+        logger.info(
+            'Condition not met for %r, retrying in %ss (attempt %d/%d).', result, delay, attempt + 1, max_retries
+        )
+        await asyncio.sleep(delay)
+        result = await _maybe_await(fn())
+    return result
+
+
+@overload
+async def poll_until_condition(
+    fn: Callable[[], Awaitable[T]],
+    condition: Callable[[T], bool] = ...,
+    *,
+    timeout: float = ...,
+    poll_interval: float = ...,
+) -> T: ...
+@overload
+async def poll_until_condition(
+    fn: Callable[[], T],
+    condition: Callable[[T], bool] = ...,
+    *,
+    timeout: float = ...,
+    poll_interval: float = ...,
+) -> T: ...
+async def poll_until_condition(
+    fn: Callable[[], Awaitable[T] | T],
+    condition: Callable[[T], bool] = bool,
+    *,
+    timeout: float = 5,
+    poll_interval: float = 1,
+) -> T:
+    """Poll `fn` until `condition(result)` is True or the timeout expires.
+
+    Polls `fn` at `poll_interval`-second intervals until `condition` is satisfied or `timeout` seconds have elapsed.
+    Returns the last polled result regardless of whether the condition was met, so the caller can run its own
+    assertion. The default condition checks for a truthy result.
+
+    Use this instead of a fixed `asyncio.sleep` when waiting for eventually-consistent state (e.g. a freshly
+    created resource appearing in a listing) that may take a variable amount of time to propagate. Unlike
+    `call_with_exp_backoff`, the interval between polls stays constant.
+    """
+    deadline = time.monotonic() + timeout
+    result = await _maybe_await(fn())
+    while not condition(result):
+        remaining = deadline - time.monotonic()
+        if remaining <= 0:
+            break
+        await asyncio.sleep(min(poll_interval, remaining))
+        result = await _maybe_await(fn())
+    return result
+
+
 async def collect_iterate_until_present(
     iterator_factory: Callable[[], Iterator[_HasIdT] | AsyncIterator[_HasIdT]],
     expected_ids: set[str],
@@ -132,7 +237,7 @@ async def collect_iterate_until_present(
 
     Handles eventual consistency on listing endpoints: under parallel load a freshly
     created resource may not appear in the listing for a short window. Each attempt
-    builds a fresh iterator via `iterator_factory`, drains it, and breaks early once
+    builds a fresh iterator via `iterator_factory`, drains it, and stops early once
     `expected_ids` is a subset of the collected items' `.id` values. The most recent
     collection is returned regardless of whether the condition was met, so the caller
     can run its own assertion with a helpful failure message.
@@ -141,19 +246,17 @@ async def collect_iterate_until_present(
         iterator_factory: No-arg callable returning a fresh iterator on each call.
         expected_ids: IDs that must all appear in the collected items.
         item_type: Asserted to match the runtime type of each yielded item.
-        is_async: Whether the iterator is async (and so are sleeps).
-        max_attempts: Maximum number of polling rounds.
-        interval: Seconds to sleep before each attempt.
+        is_async: Whether the iterator is async.
+        max_attempts: Approximate number of polling rounds; polling stops after `max_attempts * interval` seconds.
+        interval: Seconds to sleep between attempts.
 
     Returns:
         The most recently collected items.
     """
-    collected: list[_HasIdT] = []
-    for attempt in range(max_attempts):
-        if attempt > 0:
-            await maybe_sleep(interval, is_async=is_async)
+
+    async def drain() -> list[_HasIdT]:
         iterator = iterator_factory()
-        collected = []
+        collected: list[_HasIdT] = []
         if is_async:
             assert isinstance(iterator, AsyncIterator)
             async for item in iterator:
@@ -164,9 +267,14 @@ async def collect_iterate_until_present(
             for item in iterator:
                 assert isinstance(item, item_type)
                 collected.append(item)
-        if expected_ids.issubset(item.id for item in collected):
-            break
-    return collected
+        return collected
+
+    return await poll_until_condition(
+        drain,
+        lambda collected: expected_ids.issubset(item.id for item in collected),
+        timeout=max_attempts * interval,
+        poll_interval=interval,
+    )
 
 
 # ============================================================================