Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/unit/_autoscaling/test_autoscaled_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from crawlee._autoscaling._types import LoadRatioInfo, SystemInfo
from crawlee._types import ConcurrencySettings
from crawlee._utils.time import measure_time
from tests.unit.utils import wait_for_condition
from tests.unit.utils import poll_until_condition

if TYPE_CHECKING:
from collections.abc import Awaitable
Expand Down Expand Up @@ -192,20 +192,20 @@ def get_historical_system_info() -> SystemInfo:

try:
# Wait until concurrency scales up above 1.
await wait_for_condition(lambda: pool.desired_concurrency > 1, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency > 1, timeout=5.0)

# Wait until concurrency reaches maximum.
await wait_for_condition(lambda: pool.desired_concurrency == 4, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency == 4, timeout=5.0)

# Multiple concurrent workers should have completed more tasks than a single worker could.
await wait_for_condition(lambda: done_count > 10, timeout=5.0)
assert await poll_until_condition(lambda: done_count > 10, timeout=5.0)

# Simulate CPU overload and wait for the pool to scale down.
overload_active = True
await wait_for_condition(lambda: pool.desired_concurrency < 4, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency < 4, timeout=5.0)

# Wait until the pool scales all the way down to minimum.
await wait_for_condition(lambda: pool.desired_concurrency == 1, timeout=5.0)
assert await poll_until_condition(lambda: pool.desired_concurrency == 1, timeout=5.0)
finally:
pool_run_task.cancel()
with suppress(asyncio.CancelledError):
Expand Down
122 changes: 104 additions & 18 deletions tests/unit/utils.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,120 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import sys
from typing import TYPE_CHECKING
import time
from typing import TYPE_CHECKING, TypeVar, cast, overload

import pytest

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Awaitable, Callable

logger = logging.getLogger(__name__)

T = TypeVar('T')

run_alone_on_mac = pytest.mark.run_alone if sys.platform == 'darwin' else lambda x: x


async def wait_for_condition(
condition: Callable[[], bool],
async def _maybe_await(value: Awaitable[T] | T) -> T:
    """Await `value` if it is awaitable, otherwise return it unchanged.

    Lets `call_with_exp_backoff` and `poll_until_condition` accept both sync and async callables: they call
    the user-supplied function and pass whatever it returned through this helper.

    Args:
        value: Either a plain value or an awaitable (as reported by `inspect.isawaitable`).

    Returns:
        The awaited result when `value` is awaitable, otherwise `value` itself.
    """
    if inspect.isawaitable(value):
        # `cast` only narrows the static type; it has no runtime effect.
        return await cast('Awaitable[T]', value)
    return cast('T', value)


@overload
async def call_with_exp_backoff(
    fn: Callable[[], Awaitable[T]],
    condition: Callable[[T], bool] = ...,
    *,
    max_retries: int = ...,
    base_delay: float = ...,
) -> T: ...
@overload
async def call_with_exp_backoff(
    fn: Callable[[], T],
    condition: Callable[[T], bool] = ...,
    *,
    max_retries: int = ...,
    base_delay: float = ...,
) -> T: ...
async def call_with_exp_backoff(
    fn: Callable[[], Awaitable[T] | T],
    condition: Callable[[T], bool] = bool,
    *,
    max_retries: int = 5,
    base_delay: float = 1.0,
) -> T:
    """Call `fn`, retrying with exponential backoff until `condition(result)` is True.

    Calls `fn` and checks whether `condition` holds for its result. If it does not, `fn` is retried up to
    `max_retries` times, sleeping `base_delay * 2 ** attempt` seconds before each retry. The last result is
    returned regardless of whether the condition was ever satisfied, so the caller can run its own assertion.

    This is useful for eventually-consistent state where the expected value may take a moment to appear. The
    default condition checks for a truthy result. Pass `max_retries=0` to call `fn` exactly once without retries.

    Unlike `poll_until_condition`, the delay between attempts grows exponentially rather than staying constant.

    Args:
        fn: Zero-argument callable, sync or async, producing the value to check.
        condition: Predicate applied to each result; defaults to `bool` (truthiness).
        max_retries: Maximum number of retries after the initial call.
        base_delay: Delay in seconds before the first retry; doubled for every following retry.

    Returns:
        The first result satisfying `condition`, or the result of the final call if none did.
    """
    result = await _maybe_await(fn())
    for attempt in range(max_retries):
        if condition(result):
            return result
        delay = base_delay * 2**attempt
        logger.info(
            'Condition not met for %r, retrying in %ss (attempt %d/%d).', result, delay, attempt + 1, max_retries
        )
        await asyncio.sleep(delay)
        result = await _maybe_await(fn())
    # The result of the last retry is returned without re-checking `condition`; the caller asserts on it.
    return result


@overload
async def poll_until_condition(
    fn: Callable[[], Awaitable[T]],
    condition: Callable[[T], bool] = ...,
    *,
    timeout: float = ...,
    poll_interval: float = ...,
) -> T: ...
@overload
async def poll_until_condition(
    fn: Callable[[], T],
    condition: Callable[[T], bool] = ...,
    *,
    timeout: float = ...,
    poll_interval: float = ...,
) -> T: ...
async def poll_until_condition(
    fn: Callable[[], Awaitable[T] | T],
    condition: Callable[[T], bool] = bool,
    *,
    timeout: float = 5,
    poll_interval: float = 0.05,
) -> T:
    """Poll `fn` until `condition(result)` is True or the timeout expires.

    Polls `fn` at `poll_interval`-second intervals until `condition` is satisfied or `timeout` seconds have elapsed.
    Returns the last polled result regardless of whether the condition was met, so the caller can run its own
    assertion. The default condition checks for a truthy result.

    Use this instead of a fixed `asyncio.sleep` when waiting for some state to settle (e.g. autoscaling
    concurrency) that may take a variable amount of time. Unlike `call_with_exp_backoff`, the interval between
    polls stays constant.

    Args:
        fn: Zero-argument callable, sync or async, producing the value to check.
        condition: Predicate applied to each polled result; defaults to `bool` (truthiness).
        timeout: Maximum time in seconds to keep polling.
        poll_interval: Time in seconds between polls.

    Returns:
        The first result satisfying `condition`, or the last polled result if the timeout expired first.
    """
    deadline = time.monotonic() + timeout
    # `fn` is always called at least once, even when `timeout` is zero or negative.
    result = await _maybe_await(fn())
    while not condition(result):
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        # Never sleep past the deadline, so the total wait stays close to `timeout`.
        await asyncio.sleep(min(poll_interval, remaining))
        result = await _maybe_await(fn())
    return result
Loading