From d2bf71f3c81d5f4d789c230eff00d08493fb4669 Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Mon, 25 May 2026 10:38:26 -0500 Subject: [PATCH 1/2] feat(waterdata): Add async parallel chunker over httpx.AsyncClient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a parallel fan-out path to `multi_value_chunked`. When `API_USGS_CONCURRENT` resolves to >1 (default: 16), the decorator runs the sub-requests of an over-budget plan concurrently under one shared `httpx.AsyncClient`, instead of issuing them serially. Falls back to the serial sync path (with a one-time UserWarning) when no async fetch sibling is wired or when an asyncio event loop is already running (Jupyter, IPython, async apps — `asyncio.run` would otherwise raise). Architecture (`dataretrieval/waterdata/chunking.py`): * `_fan_out_async(plan, fetch_once, fetch_async, *, max_concurrent)` is the orchestrator: it dispatches every sub-request concurrently via `asyncio.gather(return_exceptions=True)`. Completed pairs survive a sibling's transient failure, so partial state stays recoverable through `ChunkedCall.resume()` on the sync path. * Failure precedence in the gather: 1. Cancellation/interrupt signals (CancelledError, KeyboardInterrupt, SystemExit) propagate unmodified — never wrapped as transients. Cancellation is asyncio's abort signal; rewriting it as ChunkInterrupted would silently consume the user's stop request. 2. Recognized transients (RateLimited, ServiceUnavailable, bare httpx.HTTPError) wrap as ChunkInterrupted so the user gets a resumable handle even when a non-transient bug landed earlier in submission order. 3. Otherwise raise the first failure in submission order, preserving its type. * `_execute_in_parallel` owns the sync→async bridge: `asyncio.run` dispatch with the `fetch_async is None` and running-event-loop fallbacks (each a one-time UserWarning, then serial). * `_publish_async_client` / `get_active_async_client` / `_chunked_async_client` ContextVar let async paginated-loop helpers (`_walk_pages_async`, `_paginate_async`) reuse one `AsyncClient` connection pool across every concurrent sub-request. Wiring (`dataretrieval/waterdata/utils.py`): * `_walk_pages_async`, `_paginate_async`, `_client_for_async`, `_fetch_once_async` — async siblings of the sync paginate path, sharing the same `parse_response` / `follow_up` callbacks and the `_ogc_parse_response` parser. * The `@chunking.multi_value_chunked(fetch_async=_fetch_once_async)` decorator on `_fetch_once` wires the async sibling so the parallel path is available to every Water Data OGC getter. * `ChunkedCall.record()` encapsulates the completion write so the serial loop and the parallel fan-out share it; `_chunks` is a sparse index map so a parallel partial-failure resumes correctly via the sync path. Concurrency cap (`API_USGS_CONCURRENT`): * Integer N >= 1: bounded fan-out (semaphore-gated, N=1 forces serial sync). Default 16 — the server-friendly sweet spot. * `unbounded`: no per-call cap (`Semaphore(sys.maxsize)`). * Unset: default 16. Retries (`API_USGS_RETRIES`, default 4; `0` disables): each sub-request is retried on a transient failure with exponential backoff + full jitter, so a large fan-out completes through the AWS API Gateway's burst throttling and the occasional backend straggler instead of aborting on the first 429/5xx/timeout. * `RetryPolicy` — a frozen value object owning the timing decisions (`from_env`, `should_retry`, `backoff`). Full jitter (`random.uniform(0, ceiling)`) de-correlates the concurrent retries so they don't re-burst in lockstep. A server `Retry-After` overrides the computed backoff, unless it exceeds `retry_after_cap` (60s) — a multi-minute quota-window reset escalates to the resumable interruption instead of blocking the call inline. * `_retryable` — chain-walking predicate, narrower than `_classify_chunk_error`: retries `RateLimited` / `ServiceUnavailable` / `httpx.TransportError` but NOT `httpx.InvalidURL`. * `_retry_sync` / `_retry_async` drivers wrap the per-sub-request fetch at both seams (`ChunkedCall._issue`, `_fan_out_async.track`); the async retry runs inside the semaphore, so a backing-off chunk holds its slot and effective concurrency shrinks under throttling. On exhaustion the last exception re-raises into the existing `wrap_failure` path, so `.resume()` stays the escape hatch. * `ProgressReporter.note_retry` surfaces the backoff on the status line ("retrying (attempt N, waiting Ns)"), cleared by the next page. Test scaffolding: `tests/conftest.py` extends the `_serial_chunker` autouse fixture to pin `API_USGS_CONCURRENT=1` and `API_USGS_RETRIES=0` so the existing mocked suite stays on the deterministic serial path with transients surfacing immediately; async and retry tests opt back in by re-setting the env vars inside their body. Tests: async-path coverage in `tests/waterdata_chunking_test.py` (one-call-per-sub-request, mid-fan-out transient yields resumable ChunkInterrupted, fallback-to-serial parametrized over running-loop and missing-fetch_async, cancellation-wins-over- transient-sibling regression), plus retry coverage (policy math/jitter bounds, `_retryable` taxonomy, sync+async transient-then-success, exhausted-still-resumable, long-`Retry-After` escalation). `tests/waterdata_progress_test.py` adds progress integration for `_fan_out_async` / `_paginate_async` and the `note_retry` render/clear. `tests/waterdata_utils_test.py` adds a `_walk_pages_async` initial-parse-error test. Test suite: 435 passing, 2 skipped (mocked); ruff clean. Co-Authored-By: Claude Opus 4.7 (1M context) --- dataretrieval/waterdata/_progress.py | 16 + dataretrieval/waterdata/chunking.py | 606 +++++++++++++++++++++++++-- dataretrieval/waterdata/utils.py | 147 ++++++- tests/conftest.py | 31 +- tests/waterdata_chunking_test.py | 420 +++++++++++++++++++ tests/waterdata_progress_test.py | 141 +++++++ tests/waterdata_utils_test.py | 31 ++ 7 files changed, 1356 insertions(+), 36 deletions(-) diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py index 7263d555..7104f3af 100644 --- a/dataretrieval/waterdata/_progress.py +++ b/dataretrieval/waterdata/_progress.py @@ -121,6 +121,9 @@ def __init__( # The hourly request quota (``x-ratelimit-limit``), shown as the # denominator when the server reports it. self.rate_limit: str | None = None + # Transient note shown while a sub-request backs off before a + # retry; cleared by the next page/chunk so it doesn't linger. + self.retry_note: str | None = None self._last_len = 0 # Whether anything was actually written to the stream — drives whether # close() needs a terminating newline. (``current_chunk`` is a poor @@ -140,6 +143,7 @@ def start_chunk(self, index: int) -> None: avoids a premature "0 pages" frame before the first page arrives. """ self.current_chunk = index + self.retry_note = None if self.total_chunks > 1: self._render() @@ -147,6 +151,16 @@ def add_page(self, rows: int = 0) -> None: """Record one fetched page carrying ``rows`` rows and redraw.""" self.pages += 1 self.rows += int(rows) + self.retry_note = None + self._render() + + def note_retry(self, *, attempt: int, wait: float) -> None: + """Show that a sub-request is backing off before retry ``attempt``. + + Cleared by the next :meth:`add_page` / :meth:`start_chunk` so the + line returns to normal progress once the retry succeeds. + """ + self.retry_note = f"retrying (attempt {attempt}, waiting {wait:.0f}s)" self._render() def set_rate_remaining( @@ -179,6 +193,8 @@ def _format(self) -> str: else: segment = f"{remaining} requests remaining" parts.append(segment) + if self.retry_note is not None: + parts.append(self.retry_note) if self.service: return f"Retrieving: {self.service} · " + " · ".join(parts) return "Progress: " + " · ".join(parts) diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 36ee24fd..1e3b429d 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -9,14 +9,34 @@ sub-request URL fits. Requests that already fit get a trivial single-step plan — ``ChunkedCall`` has one code path either way. +Concurrency: when ``API_USGS_CONCURRENT`` is set to an integer N > 1 +(or the literal ``unbounded``), ``multi_value_chunked`` fans the plan +out across ``N`` async coroutines sharing one ``httpx.AsyncClient`` +instead of issuing sub-requests serially. ``N=1`` forces the +synchronous path. The default (16) is the server-friendly sweet +spot; higher values can trip USGS burst-protection 5xx in practice. +The wrapper falls back to the serial path (with a ``UserWarning``) +when an asyncio event loop is already running (Jupyter / IPython / +async apps) or when no async fetch sibling is wired into the +decorator. + +Retries: each sub-request is retried on a transient failure (429, +5xx, connect/read timeout) with exponential backoff + full jitter, +honoring a server ``Retry-After`` when present. ``API_USGS_RETRIES`` +sets the cap (default 4; ``0`` disables). A ``Retry-After`` longer +than the per-call ceiling isn't slept off inline — it escalates to +the resumable interruption below so a multi-minute quota-window +reset doesn't block the call. + Interruption: any mid-stream transient failure (429, 5xx) surfaces as a ``ChunkInterrupted`` subclass — ``QuotaExhausted`` for 429, ``ServiceInterrupted`` for 5xx. The exception carries ``.call``, a ``ChunkedCall`` handle that owns the already-completed sub-request -state. Call ``.call.resume()`` once the underlying condition -clears; only the still-pending sub-requests are re-issued. -``Retry-After`` (when the server sets it) is surfaced on the -exception as ``.retry_after``. +state (sparse-indexed on the parallel path, contiguous-prefix on +the serial path). Call ``.call.resume()`` once the underlying +condition clears; only the still-pending sub-requests are +re-issued, via the serial sync path. ``Retry-After`` (when the +server sets it) is surfaced on the exception as ``.retry_after``. Dedup: list-axis chunks don't overlap; filter-axis chunks can, so ``_combine_chunk_frames`` dedupes by feature ``id``. ``properties``, @@ -27,11 +47,17 @@ from __future__ import annotations +import asyncio import copy import functools import itertools import math -from collections.abc import Callable, Iterator +import os +import random +import sys +import time +import warnings +from collections.abc import Awaitable, Callable, Iterator from contextlib import contextmanager, suppress from contextvars import ContextVar from dataclasses import dataclass @@ -93,15 +119,172 @@ # Response header USGS uses to advertise remaining hourly quota. _QUOTA_HEADER = "x-ratelimit-remaining" +# Environment variable that controls async fan-out concurrency. Read +# at call time (not import) so test patches via ``monkeypatch.setenv`` +# take effect. The default (16) is the server-friendly sweet spot: +# higher values trip the upstream into 5xx burst-protection in +# practice. Set to ``1`` to force the serial sync path, set to +# ``unbounded`` for no per-call cap (use sparingly — you own the +# upstream-burst risk). +_CONCURRENCY_ENV = "API_USGS_CONCURRENT" +_CONCURRENCY_DEFAULT = 16 +_CONCURRENCY_UNBOUNDED = "unbounded" + + +def _read_concurrency_env() -> int | None: + """ + Resolve the ``API_USGS_CONCURRENT`` env var to a parallelism cap. + + Returns + ------- + int or None + ``1`` for the serial sync path; an integer >1 for bounded + parallelism; ``None`` to disable the per-call cap entirely + (``unbounded`` keyword). Unset → default of + ``_CONCURRENCY_DEFAULT``. + """ + raw = os.environ.get(_CONCURRENCY_ENV) + if raw is None: + return _CONCURRENCY_DEFAULT + raw = raw.strip() + if raw == "": + return _CONCURRENCY_DEFAULT + if raw.lower() == _CONCURRENCY_UNBOUNDED: + return None + try: + value = int(raw) + except ValueError as exc: + raise ValueError( + f"{_CONCURRENCY_ENV} must be a positive integer or " + f"'{_CONCURRENCY_UNBOUNDED}'; got {raw!r}." + ) from exc + if value < 1: + raise ValueError( + f"{_CONCURRENCY_ENV} must be >= 1 (got {value}); use " + f"'{_CONCURRENCY_UNBOUNDED}' to disable the cap." + ) + return value + + +# Retry-with-backoff for transient sub-request failures (429 / 5xx / +# connect-read timeouts). The env var is read at call time so test +# ``monkeypatch.setenv`` takes effect; the timing constants are +# module-level so power users (and tests) can ``monkeypatch.setattr`` +# them. Defaults: 4 retries, 0.5s base doubling under full jitter up to +# a 30s per-attempt ceiling, and honor a server ``Retry-After`` up to +# 60s before escalating to a resumable interruption instead. +_RETRIES_ENV = "API_USGS_RETRIES" +_RETRIES_DEFAULT = 4 +_RETRY_BASE_BACKOFF = 0.5 +_RETRY_MAX_BACKOFF = 30.0 +_RETRY_AFTER_CAP = 60.0 + + +def _read_retries_env() -> int: + """ + Resolve the ``API_USGS_RETRIES`` env var to a max-retry count. + + Returns + ------- + int + Number of retries after the first attempt; ``0`` disables + retrying. Unset/blank → ``_RETRIES_DEFAULT``. + """ + raw = os.environ.get(_RETRIES_ENV) + if raw is None or raw.strip() == "": + return _RETRIES_DEFAULT + try: + value = int(raw.strip()) + except ValueError as exc: + raise ValueError( + f"{_RETRIES_ENV} must be a non-negative integer (got {raw!r})." + ) from exc + if value < 0: + raise ValueError(f"{_RETRIES_ENV} must be >= 0 (got {value}).") + return value + + +@dataclass(frozen=True) +class RetryPolicy: + """Bounded retry-with-backoff config for transient sub-request failures. + + An immutable value object that owns the *timing* decisions; the + exception taxonomy ("is this worth retrying at all?") lives in + :func:`_retryable`. Backoff is exponential with **full jitter** + (:func:`random.uniform` over ``[0, ceiling]``) so the concurrent + fan-out's retries don't re-burst in lockstep. A server ``Retry-After`` + hint, when present, overrides the computed backoff — unless it exceeds + :attr:`retry_after_cap`, in which case retrying stops and the failure + surfaces as a resumable :class:`ChunkInterrupted` (a multi-minute + quota-window reset shouldn't block the call inline). + + Attributes + ---------- + max_retries : int + Retries attempted after the first try; ``0`` disables retrying. + base_backoff : float + Seconds; the jitter ceiling for the first retry, doubled each + subsequent attempt. + max_backoff : float + Upper bound on any single attempt's backoff ceiling. + retry_after_cap : float + Largest ``Retry-After`` (seconds) honored inline; longer hints + escalate to a resumable interruption. + """ + + max_retries: int = _RETRIES_DEFAULT + base_backoff: float = _RETRY_BASE_BACKOFF + max_backoff: float = _RETRY_MAX_BACKOFF + retry_after_cap: float = _RETRY_AFTER_CAP + + @classmethod + def from_env(cls) -> RetryPolicy: + """Build a policy, resolving ``max_retries`` from ``API_USGS_RETRIES``.""" + return cls(max_retries=_read_retries_env()) + + def should_retry(self, attempt: int, retry_after: float | None) -> bool: + """Whether a just-failed ``attempt`` (1-based) warrants another try. + + A ``Retry-After`` longer than ``retry_after_cap`` is *not* slept + off inline — it returns ``False`` so the failure escalates to a + resumable interruption instead of blocking the call for minutes. + """ + if attempt > self.max_retries: + return False + return retry_after is None or retry_after <= self.retry_after_cap + + def backoff(self, attempt: int, retry_after: float | None) -> float: + """Seconds to wait before retry ``attempt`` (1-based).""" + if retry_after is not None: + return retry_after + ceiling = min(self.max_backoff, self.base_backoff * 2 ** (attempt - 1)) + return random.uniform(0.0, ceiling) + + +# Default for direct ``ChunkedCall`` / ``ChunkPlan.execute`` construction +# (and tests): no retrying. The production decorator path explicitly passes +# ``RetryPolicy.from_env()`` so retries are on by default there. +_NO_RETRY = RetryPolicy(max_retries=0) + + # Client shared across all sub-requests of a single chunked call so # paginated-loop helpers downstream (``_walk_pages``) reuse one -# connection pool across the whole call. ``None`` when not inside a +# connection pool across the whole fan-out. ``None`` when not inside a # chunked call — paginated helpers fall back to their own short-lived # client in that case. _chunked_client: ContextVar[httpx.Client | None] = ContextVar( "_chunked_client", default=None ) +# Async sibling of ``_chunked_client``. Published by +# ``_publish_async_client`` during ``_fan_out_async`` so async +# paginated-loop helpers reuse one ``httpx.AsyncClient`` (and its +# connection pool) across every concurrent sub-request of a single +# chunked call. +_chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar( + "_chunked_async_client", default=None +) + @contextmanager def _publish_client(client: httpx.Client) -> Iterator[None]: @@ -117,6 +300,20 @@ def _publish_client(client: httpx.Client) -> Iterator[None]: _chunked_client.reset(token) +@contextmanager +def _publish_async_client(client: httpx.AsyncClient) -> Iterator[None]: + """ + Make ``client`` visible to :func:`get_active_async_client` for the + duration of the ``with`` block. Async sibling of + :func:`_publish_client`. + """ + token = _chunked_async_client.set(client) + try: + yield + finally: + _chunked_async_client.reset(token) + + def get_active_client() -> httpx.Client | None: """ Return the chunker's currently-published sync client, or ``None``. @@ -134,6 +331,16 @@ def get_active_client() -> httpx.Client | None: return _chunked_client.get() +def get_active_async_client() -> httpx.AsyncClient | None: + """ + Return the chunker's currently-published async client, or ``None``. + + Async sibling of :func:`get_active_client`. Used by async + paginated-loop helpers to reuse the per-call AsyncClient pool. + """ + return _chunked_async_client.get() + + # Separators the two axis kinds use to join their atoms back into # URL text. List axes comma-join values (``site=USGS-A,USGS-B``); the # filter axis OR-joins clauses (``filter=a='1' OR a='2'``). @@ -141,6 +348,9 @@ def get_active_client() -> httpx.Client | None: _OR_SEP = " OR " _FetchOnce = Callable[[dict[str, Any]], tuple[pd.DataFrame, httpx.Response]] +_FetchOnceAsync = Callable[ + [dict[str, Any]], Awaitable[tuple[pd.DataFrame, httpx.Response]] +] class _RetryableTransportError(RuntimeError): @@ -767,7 +977,9 @@ def iter_sub_args(self) -> Iterator[dict[str, Any]]: sub_args[axis.arg_key] = axis.render(chunk) yield sub_args - def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response]: + def execute( + self, fetch_once: _FetchOnce, retry_policy: RetryPolicy = _NO_RETRY + ) -> tuple[pd.DataFrame, httpx.Response]: """ Run the plan and return the combined ``(frame, response)``. @@ -779,6 +991,9 @@ def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response] fetch_once : Callable Function that issues a single sub-request, given the substituted args dict, and returns ``(frame, response)``. + retry_policy : RetryPolicy, optional + Per-sub-request retry-with-backoff policy. Defaults to + :data:`_NO_RETRY`; the decorator passes ``RetryPolicy.from_env()``. Returns ------- @@ -796,7 +1011,7 @@ def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response] :class:`ServiceInterrupted` for 5xx). The resumable handle is on ``exc.call``. """ - return ChunkedCall(self, fetch_once).resume() + return ChunkedCall(self, fetch_once, retry_policy).resume() def _classify_chunk_error( @@ -850,6 +1065,93 @@ def _classify_chunk_error( return None +def _retryable(exc: BaseException) -> tuple[bool, float | None]: + """ + Decide whether ``exc`` is a transient worth an automatic retry. + + Narrower than :func:`_classify_chunk_error`: it retries rate limits + (429), service errors (5xx), and genuine transport transients + (:class:`httpx.TransportError` — ``ConnectError``, ``ReadTimeout``, …) + but NOT :class:`httpx.InvalidURL` (a too-long server cursor URL won't + fix on retry, though it stays *resumable*). Walks the ``__cause__`` + chain because ``_walk_pages`` re-wraps mid-pagination failures as + ``RuntimeError``. + + Returns + ------- + tuple[bool, float or None] + ``(retryable, retry_after)`` — the server ``Retry-After`` hint + (seconds) when the transient carried one, else ``None``. + """ + cur: BaseException | None = exc + while cur is not None: + if isinstance(cur, (RateLimited, ServiceUnavailable)): + return True, cur.retry_after + if isinstance(cur, httpx.TransportError): + return True, None + cur = cur.__cause__ + return False, None + + +# Sleep hooks, indirected through module globals so tests can +# ``monkeypatch.setattr`` them to no-ops instead of waiting for real +# backoff. Production uses the stdlib calls. +_SLEEP = time.sleep +_ASLEEP = asyncio.sleep + + +def _note_retry(attempt: int, wait: float) -> None: + """Surface an imminent retry on the active progress reporter, if any.""" + reporter = _progress.current() + if reporter is not None: + reporter.note_retry(attempt=attempt, wait=wait) + + +def _retry_sync( + fn: Callable[[], tuple[pd.DataFrame, httpx.Response]], + policy: RetryPolicy, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Call ``fn`` with bounded retry-with-backoff on transient failures. + + On a non-retryable error, or once ``policy`` is exhausted (or the + server's ``Retry-After`` is too long to absorb inline), the last + exception propagates unchanged so the caller's existing handling wraps + it as a resumable :class:`ChunkInterrupted`. + """ + attempt = 0 + while True: + try: + return fn() + except Exception as exc: # noqa: BLE001 — re-raised unless retryable + retryable, retry_after = _retryable(exc) + attempt += 1 + if not retryable or not policy.should_retry(attempt, retry_after): + raise + delay = policy.backoff(attempt, retry_after) + _note_retry(attempt, delay) + _SLEEP(delay) + + +async def _retry_async( + afn: Callable[[], Awaitable[tuple[pd.DataFrame, httpx.Response]]], + policy: RetryPolicy, +) -> tuple[pd.DataFrame, httpx.Response]: + """Async sibling of :func:`_retry_sync` (awaits :func:`asyncio.sleep`).""" + attempt = 0 + while True: + try: + return await afn() + except Exception as exc: # noqa: BLE001 — re-raised unless retryable + retryable, retry_after = _retryable(exc) + attempt += 1 + if not retryable or not policy.should_retry(attempt, retry_after): + raise + delay = policy.backoff(attempt, retry_after) + _note_retry(attempt, delay) + await _ASLEEP(delay) + + def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: """ Concatenate per-chunk frames, dropping empties and deduping by ``id``. @@ -989,9 +1291,11 @@ class ChunkedCall: :meth:`resume` is idempotent: it iterates :meth:`ChunkPlan.iter_sub_args` (deterministic order) and skips any index whose result is already in ``self._chunks``. The - completion set is a ``dict[int, (df, response)]`` keyed by - sub-args index; a subsequent ``resume`` only re-issues - sub-requests whose index isn't already present. + completion set is a sparse ``dict[int, (df, response)]`` so the + parallel path can record scattered completions (e.g. indices + [0, 2, 5] after siblings [1, 3, 4] failed) and a subsequent + ``resume`` only re-issues the missing indices — via the serial + sync ``fetch_once`` path. Parameters ---------- @@ -1015,13 +1319,53 @@ class ChunkedCall: when nothing has completed yet (live; recomputed per access). """ - def __init__(self, plan: ChunkPlan, fetch_once: _FetchOnce) -> None: + def __init__( + self, + plan: ChunkPlan, + fetch_once: _FetchOnce, + retry_policy: RetryPolicy = _NO_RETRY, + ) -> None: self.plan = plan self.fetch_once = fetch_once - # Completed (frame, response) pairs keyed by sub-args index; - # ``resume()`` skips indices already present. + self.retry_policy = retry_policy + # Completed (frame, response) pairs keyed by sub-args index. + # Sparse so the parallel fan-out path can record scattered + # completions (e.g. indices [0, 2, 5] when 1/3/4 failed) and a + # subsequent ``resume()`` only re-issues the missing indices. + # On the serial path this fills contiguously from 0. self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} + def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: + """Record a completed sub-request's ``(frame, response)`` pair + under its sub-args index. Used by both the serial loop in + :meth:`resume` and the parallel fan-out in + :func:`_fan_out_async` so the completion set stays + encapsulated.""" + self._chunks[index] = pair + + def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: + """Build the matching :class:`ChunkInterrupted` carrying this + call when ``exc`` is a recognized transient transport failure; + return ``None`` for unrecognized failures so the caller can + re-raise. Encapsulates the + ``classify → instantiate-with-call-state`` recipe so + :class:`ChunkedCall`'s private fields stay private.""" + classification = _classify_chunk_error(exc) + if classification is None: + return None + interrupted_class, retry_after = classification + return interrupted_class( + completed_chunks=len(self._chunks), + total_chunks=self.plan.total, + call=self, + retry_after=retry_after, + cause=exc, + ) + + @property + def completed_chunks(self) -> int: + return len(self._chunks) + def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]: return [self._chunks[i] for i in sorted(self._chunks)] @@ -1078,7 +1422,9 @@ def resume(self) -> tuple[pd.DataFrame, httpx.Response]: Idempotent: only sub-requests whose index isn't already in ``self._chunks`` are re-issued. Sub-args order matches - :meth:`ChunkPlan.iter_sub_args` and is deterministic. + :meth:`ChunkPlan.iter_sub_args` and is deterministic, so a + parallel-mode partial completion (sparse indices) resumes + correctly via the sync path. Returns ------- @@ -1132,24 +1478,148 @@ def _issue(self, index: int, sub_args: dict[str, Any]) -> None: three feed :func:`_classify_chunk_error`. """ try: - self._chunks[index] = self.fetch_once(sub_args) + chunk = _retry_sync(lambda: self.fetch_once(sub_args), self.retry_policy) except (RuntimeError, httpx.HTTPError, httpx.InvalidURL) as exc: - classification = _classify_chunk_error(exc) - if classification is None: + interrupted = self.wrap_failure(exc) + if interrupted is None: raise - interrupted_class, retry_after = classification - raise interrupted_class( - completed_chunks=len(self._chunks), - total_chunks=self.plan.total, - call=self, - retry_after=retry_after, - cause=exc, - ) from exc + raise interrupted from exc + self.record(index, chunk) + + +async def _fan_out_async( + plan: ChunkPlan, + fetch_once: _FetchOnce, + fetch_async: _FetchOnceAsync, + *, + max_concurrent: int | None, + retry_policy: RetryPolicy = _NO_RETRY, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Execute ``plan`` concurrently under one shared + :class:`httpx.AsyncClient`. + + The fan-out preserves the same resumability contract the serial + :class:`ChunkedCall` path provides: + + * **Resumable interruptions.** Sub-requests run under + ``asyncio.gather`` with ``return_exceptions=True`` so completed + sub-requests survive a sibling's transient failure. On a + recognized transient (:class:`RateLimited`, + :class:`ServiceUnavailable`) a :class:`ChunkInterrupted` + subclass is raised with ``.call`` set to a + :class:`ChunkedCall` carrying the sparse completed sub-args; + ``exc.call.resume()`` re-issues only the unfinished ones via + the sync ``fetch_once`` path. + + In-flight sub-requests are capped by an + :class:`asyncio.Semaphore`; ``max_concurrent=None`` ("unbounded") + uses ``sys.maxsize`` so every call site can take the same + ``async with semaphore`` path. The shared client is published on + :data:`_chunked_async_client` so async paginated-loop helpers + reuse its connection pool. + + Parameters + ---------- + plan : ChunkPlan + Pre-built plan whose sub-args sequence drives the fan-out. + fetch_once : Callable + Sync per-sub-request fetcher. Used to build the resumable + :class:`ChunkedCall` returned via ``ChunkInterrupted.call``; + never invoked by this function directly. + fetch_async : Callable + Async per-sub-request fetcher returning ``(df, response)``. + max_concurrent : int or None + Maximum in-flight sub-requests. ``None`` disables the cap. + + Returns + ------- + df : pandas.DataFrame + Combined data from every sub-request. + response : httpx.Response + Aggregated response (canonical URL, last sub-request's + headers, cumulative elapsed time). + + Raises + ------ + ChunkInterrupted + On a transient sub-request failure. ``.call`` is a + :class:`ChunkedCall` holding the sparse completed sub-requests; + ``.call.resume()`` re-issues the unfinished ones serially. + """ + sub_args_list = list(plan.iter_sub_args()) + + # ``httpx.Limits()`` defaults to ``max_connections=100`` — at + # higher concurrency the pool would silently bottleneck the + # fan-out behind the connection cap. Match it to the semaphore, + # or ``None`` for truly unbounded. + limits = httpx.Limits( + max_connections=max_concurrent, max_keepalive_connections=max_concurrent + ) + # ``sys.maxsize`` stands in for "unbounded": ``asyncio.Semaphore`` + # only decrements a counter, never preallocates slots. + semaphore = asyncio.Semaphore(max_concurrent or sys.maxsize) + call = ChunkedCall(plan, fetch_once, retry_policy) + + async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: + with _publish_async_client(client): + reporter = _progress.current() + if reporter is not None: + reporter.set_chunks(plan.total) + + async def track( + offset: int, args: dict[str, Any] + ) -> tuple[pd.DataFrame, httpx.Response]: + """One sub-request (with retry) + record + progress tick. + + The retry loop runs *inside* the semaphore, so a chunk + backing off holds its slot — effective concurrency shrinks + under throttling instead of re-bursting against it. + """ + async with semaphore: + result = await _retry_async(lambda: fetch_async(args), retry_policy) + call.record(offset, result) + if reporter is not None: + reporter.start_chunk(call.completed_chunks) + return result + + # Dispatch every sub-request concurrently. ``return_exceptions`` + # keeps completed pairs after a sibling fails, so partial state + # stays recoverable via ``ChunkedCall.resume()``. Failure + # precedence: + # 1. Cancellation / interrupt signals (CancelledError, + # KeyboardInterrupt, SystemExit — non-Exception) propagate + # unmodified; wrapping them as a transient would swallow the + # user's stop signal. + # 2. Recognized transients wrap as ChunkInterrupted so the user + # gets a resumable handle even when a non-transient failure + # landed earlier in submission order. + # 3. Otherwise re-raise the first failure, preserving its type. + results = await asyncio.gather( + *(track(i, args) for i, args in enumerate(sub_args_list)), + return_exceptions=True, + ) + failures = [r for r in results if isinstance(r, BaseException)] + for exc in failures: + if not isinstance(exc, Exception): + raise exc + for exc in failures: + if (interrupted := call.wrap_failure(exc)) is not None: + raise interrupted from exc + if failures: + raise failures[0] + + ordered = call._ordered_chunks() + return ( + _combine_chunk_frames([df for df, _ in ordered]), + _combine_chunk_responses([resp for _, resp in ordered], plan.canonical_url), + ) def multi_value_chunked( *, build_request: Callable[..., httpx.Request], + fetch_async: _FetchOnceAsync | None = None, url_limit: int | None = None, ) -> Callable[[_FetchOnce], _FetchOnce]: """ @@ -1161,12 +1631,24 @@ def multi_value_chunked( single-step plan, so the decorated function has one code path either way. + When ``API_USGS_CONCURRENT`` resolves to a parallelism greater than + 1 (the default), the decorator routes execution through + :func:`_fan_out_async` over the provided ``fetch_async``. The + wrapper falls back to the synchronous :class:`ChunkedCall` path + (with a ``UserWarning``) when ``fetch_async`` wasn't wired or + when an asyncio event loop is already running (Jupyter / IPython / + async apps where ``asyncio.run`` would raise ``RuntimeError``). + Parameters ---------- build_request : Callable[..., httpx.Request] Factory that turns a kwargs dict into a sized httpx request, e.g. ``_construct_api_requests``. Called during planning to measure each candidate plan. + fetch_async : Callable, optional + Async sibling of the decorated sync fetcher. Used when + ``API_USGS_CONCURRENT`` resolves to >1; if omitted, the + wrapper warns and stays on the serial path. url_limit : int, optional Byte budget for the request (URL + body). When ``None`` (default), the module-level ``_WATERDATA_URL_BYTE_LIMIT`` is @@ -1202,8 +1684,78 @@ def wrapper( ) -> tuple[pd.DataFrame, httpx.Response]: limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit plan = ChunkPlan(args, build_request, limit) - return plan.execute(fetch_once) + concurrency = _read_concurrency_env() + retry_policy = RetryPolicy.from_env() + + # Trivial plans and explicit opt-outs stay on the sync + # path; ``_execute_in_parallel`` owns the rest of the + # serial/parallel decision (async wiring, running loop). + if plan.total <= 1 or concurrency == 1: + return plan.execute(fetch_once, retry_policy) + return _execute_in_parallel( + plan, fetch_once, fetch_async, concurrency, retry_policy + ) return wrapper return decorator + + +def _execute_in_parallel( + plan: ChunkPlan, + fetch_once: _FetchOnce, + fetch_async: _FetchOnceAsync | None, + concurrency: int | None, + retry_policy: RetryPolicy = _NO_RETRY, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Run ``plan`` on the parallel async path, falling back to the + serial sync path when the runtime can't host an event loop. + + Falls back (with a one-time :class:`UserWarning`) when: + + * ``fetch_async`` wasn't wired into the decorator, or + * an asyncio event loop is already running (Jupyter / IPython + kernels, async apps — ``asyncio.run`` would raise). + + Otherwise opens a fresh event loop via :func:`asyncio.run` and + drives :func:`_fan_out_async`. + """ + if fetch_async is None: + warnings.warn( + f"{_CONCURRENCY_ENV} is set to {concurrency} but this " + f"call site has no async fetch sibling wired; falling " + f"back to the serial path. Either set " + f"{_CONCURRENCY_ENV}=1 to silence this warning or pass " + f"fetch_async= to @multi_value_chunked.", + UserWarning, + stacklevel=3, + ) + return plan.execute(fetch_once, retry_policy) + if _running_event_loop() is not None: + warnings.warn( + "Detected a running asyncio event loop; the parallel " + f"chunker path cannot run inside one. Falling back to " + f"the serial path. Set {_CONCURRENCY_ENV}=1 to silence " + f"this warning.", + UserWarning, + stacklevel=3, + ) + return plan.execute(fetch_once, retry_policy) + return asyncio.run( + _fan_out_async( + plan, + fetch_once, + fetch_async, + max_concurrent=concurrency, + retry_policy=retry_policy, + ) + ) + + +def _running_event_loop() -> asyncio.AbstractEventLoop | None: + """Return the active asyncio event loop, or ``None`` when none.""" + try: + return asyncio.get_running_loop() + except RuntimeError: + return None diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index 66ed1723..f8475957 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -7,12 +7,14 @@ import os import re from collections.abc import ( + AsyncIterator, + Awaitable, Callable, Iterable, Iterator, Mapping, ) -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from datetime import datetime, timedelta from typing import Any, TypeVar, get_args from zoneinfo import ZoneInfo @@ -28,6 +30,7 @@ RateLimited, ServiceUnavailable, _safe_elapsed, + get_active_async_client, get_active_client, ) from dataretrieval.waterdata.types import ( @@ -837,6 +840,29 @@ def _client_for(client: httpx.Client | None) -> Iterator[httpx.Client]: yield new +@asynccontextmanager +async def _client_for_async( + client: httpx.AsyncClient | None, +) -> AsyncIterator[httpx.AsyncClient]: + """ + Yield a usable async client, picking the best available source. + Async sibling of :func:`_client_for`. + + Resolution order matches the sync version: explicit caller-owned + ``AsyncClient`` first, the chunker's shared async client next, a + fresh short-lived ``AsyncClient`` last. + """ + if client is not None: + yield client + return + shared = get_active_async_client() + if shared is not None: + yield shared + return + async with httpx.AsyncClient(**HTTPX_DEFAULTS) as new: + yield new + + def _aggregate_paginated_response( initial: httpx.Response, last: httpx.Response, @@ -998,14 +1024,86 @@ def _paginate( return pd.concat(dfs, ignore_index=True), final_response +async def _paginate_async( + initial_req: httpx.Request, + *, + parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], + follow_up: Callable[[_Cursor, httpx.AsyncClient], Awaitable[httpx.Response]], + client: httpx.AsyncClient | None = None, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Drive a paginated request to completion over an + :class:`httpx.AsyncClient`. Async sibling of :func:`_paginate`. + + Runs the same per-page loop but issues HTTP asynchronously so + multiple sub-requests of a chunked call can run concurrently from + :func:`_fan_out_async`. + """ + logger.debug("Requesting: %s", initial_req.url) + reporter = _progress.current() + async with _client_for_async(client) as sess: + resp = await sess.send(initial_req) + _raise_for_non_200(resp) + initial_response = resp + total_elapsed = _safe_elapsed(resp) + + try: + df, cursor = parse_response(resp) + except Exception as e: # noqa: BLE001 + # Mirror the sync path: initial-page parse failures + # (malformed JSON, missing ``features``, schema drift) + # get the same wrapped-message treatment as follow-up + # failures so callers see a consistent diagnostic + # regardless of which page broke. + logger.warning("Initial response parse failed.") + raise RuntimeError(_paginated_failure_message(0, e)) from e + dfs = [df] + if reporter is not None: + reporter.set_rate_remaining( + resp.headers.get(_QUOTA_HEADER), + limit=resp.headers.get("x-ratelimit-limit"), + ) + reporter.add_page(rows=len(df)) + while cursor is not None: + try: + resp = await follow_up(cursor, sess) + _raise_for_non_200(resp) + df, cursor = parse_response(resp) + dfs.append(df) + total_elapsed += _safe_elapsed(resp) + if reporter is not None: + reporter.set_rate_remaining( + resp.headers.get(_QUOTA_HEADER), + limit=resp.headers.get("x-ratelimit-limit"), + ) + reporter.add_page(rows=len(df)) + except Exception as e: # noqa: BLE001 + logger.warning( + "Request failed at cursor %r. Data download interrupted.", + cursor, + ) + raise RuntimeError(_paginated_failure_message(len(dfs), e)) from e + + # Aggregate headers / elapsed onto a COPY of the initial + # response so the user's caller never sees an in-place + # mutation of the response object they may have inspected + # mid-pagination via a hook or test fixture. + final_response = _aggregate_paginated_response( + initial_response, resp, total_elapsed + ) + return pd.concat(dfs, ignore_index=True), final_response + + def _ogc_parse_response( resp: httpx.Response, *, geopd: bool ) -> tuple[pd.DataFrame, str | None]: """Parse one OGC API page: extract the DataFrame and the next-page URL. - Coerces falsy cursors (empty href, etc.) to ``None`` so the - paginate loop's ``while cursor is not None`` terminates instead - of spinning on a meaningless value. + Shared between :func:`_walk_pages` and :func:`_walk_pages_async` + since the parse step is identical on either path. Coerces falsy + cursors (empty href, etc.) to ``None`` so the paginate loop's + ``while cursor is not None`` terminates instead of spinning on a + meaningless value. """ body = resp.json() return ( @@ -1069,6 +1167,31 @@ def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: ) +async def _walk_pages_async( + geopd: bool, + req: httpx.Request, + client: httpx.AsyncClient | None = None, +) -> tuple[pd.DataFrame, httpx.Response]: + """ + Iterate paginated OGC API responses asynchronously and aggregate + them into one DataFrame. Async sibling of :func:`_walk_pages`; + delegates to :func:`_paginate_async`. + """ + method = req.method + headers = req.headers + content = req.content if method == "POST" else None + + async def follow_up(cursor: str, sess: httpx.AsyncClient) -> httpx.Response: + return await sess.request(method, cursor, headers=headers, content=content) + + return await _paginate_async( + req, + parse_response=functools.partial(_ogc_parse_response, geopd=geopd), + follow_up=follow_up, + client=client, + ) + + def _deal_with_empty( return_list: pd.DataFrame, properties: list[str] | None, service: str ) -> pd.DataFrame: @@ -1290,8 +1413,19 @@ def get_ogc_data( return return_list, BaseMetadata(response) +async def _fetch_once_async( + args: dict[str, Any], +) -> tuple[pd.DataFrame, httpx.Response]: + """Send one prepared-args OGC request asynchronously; return the + frame + response. Async sibling of :func:`_fetch_once` used by the + parallel chunker.""" + req = _construct_api_requests(**args) + return await _walk_pages_async(geopd=GEOPANDAS, req=req) + + @chunking.multi_value_chunked( build_request=_construct_api_requests, + fetch_async=_fetch_once_async, ) def _fetch_once( args: dict[str, Any], @@ -1302,7 +1436,10 @@ def _fetch_once( parameter and the cql-text filter as a chunkable axis, greedy-halves the biggest chunk across all axes until each sub-request URL fits, and iterates the cartesian product. With no chunkable inputs the - decorator passes args through unchanged. The return shape + decorator passes args through unchanged. When ``API_USGS_CONCURRENT`` + is >1 (the default), the decorator routes execution through + :func:`_fetch_once_async` so the sub-requests run concurrently under + one shared :class:`httpx.AsyncClient`. Either way the return shape is ``(frame, response)``. """ req = _construct_api_requests(**args) diff --git a/tests/conftest.py b/tests/conftest.py index afbdfec2..5eb46cb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,15 @@ """ Test scaffolding for the dataretrieval test suite. -Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and -unmatched real requests don't fail the suite (matches the historical -``requests-mock``-style permissiveness the test code was written -against, and keeps mocked-URL setup terse). +* Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and + unmatched real requests don't fail the suite (matches the historical + ``requests-mock``-style permissiveness the test code was written + against, and keeps mocked-URL setup terse). +* Pins ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` for every + test by default so the historical mocked suite stays on the + deterministic serial chunker path and a single transient surfaces + immediately (no backoff). Async-mode and retry tests opt in by + re-setting the env vars inside their body via ``monkeypatch.setenv``. """ from __future__ import annotations @@ -30,3 +35,21 @@ def non_mocked_hosts() -> list[str]: """No hosts are exempted from mocking; every HTTP call must hit a mock registered through the ``httpx_mock`` fixture.""" return [] + + +@pytest.fixture(autouse=True) +def _serial_chunker(monkeypatch): + """Default every test to the serial, no-retry chunker path. + + Production defaults ``API_USGS_CONCURRENT`` to 16 (parallel + fan-out) and ``API_USGS_RETRIES`` to 4, but the historical tests + assume sequential, deterministic sub-request ordering — and they + mock the sync ``_walk_pages`` rather than the async sibling, and + expect a single transient to surface immediately rather than be + retried. Pinning ``API_USGS_CONCURRENT=1`` and ``API_USGS_RETRIES=0`` + keeps the test surface focused on the planner / fetch contracts; + async-mode and retry tests opt in by overriding the env inside + their body. + """ + monkeypatch.setenv("API_USGS_CONCURRENT", "1") + monkeypatch.setenv("API_USGS_RETRIES", "0") diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index 21b23757..ee129aaa 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -15,11 +15,13 @@ and then fail in production. """ +import asyncio import datetime import sys from unittest import mock from urllib.parse import quote_plus +import httpx import pandas as pd import pytest @@ -36,10 +38,14 @@ QuotaExhausted, RateLimited, RequestTooLarge, + RetryPolicy, ServiceInterrupted, ServiceUnavailable, _chunked_client, _extract_axes, + _retry_async, + _retry_sync, + _retryable, multi_value_chunked, ) from dataretrieval.waterdata.utils import _construct_api_requests @@ -1202,6 +1208,195 @@ def test_iter_sub_args_passthrough_yields_a_copy(): assert "new_key" not in plan.args +# --- async fan-out path ---------------------------------------------------- +# +# The conftest's ``_serial_chunker`` autouse pins ``API_USGS_CONCURRENT=1`` +# for the whole suite. Each test below overrides it so the wrapper takes +# the parallel branch. The decorator's ``fetch_async`` accepts any +# coroutine returning ``(df, response)`` — no real ``httpx.AsyncClient`` +# round-trip occurs, even though :func:`_fan_out_async` opens one for +# pool management. + + +def _async_chunked_fetch(monkeypatch, fetch_async, *, max_concurrent=16): + """Decorate a deterministic chunkable fetch with the parallel + path forced on via ``API_USGS_CONCURRENT``.""" + monkeypatch.setenv("API_USGS_CONCURRENT", str(max_concurrent)) + + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): + # Sync sibling — invoked on resume() after a parallel failure + # and never during the happy parallel path. + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + return fetch + + +def _atom_id(args): + """Build a deterministic id for a sub-args dict — used as the dedup key.""" + return ",".join(args["sites"]) if isinstance(args["sites"], list) else args["sites"] + + +def _ok_response(remaining=None): + headers = {} if remaining is None else {_QUOTA_HEADER: str(remaining)} + return mock.Mock(elapsed=datetime.timedelta(seconds=0.1), headers=headers) + + +def test_async_fan_out_emits_one_call_per_sub_request(monkeypatch): + """Parallel mode hits every sub-args once — same coverage as the + sync ``ChunkedCall`` path, just dispatched concurrently.""" + seen_args = [] + + async def fetch_async(args): + seen_args.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + # Planner halves the 4-site list, so 2 sub-args → 2 async calls. + assert len(seen_args) > 1 + # Every sub-args atom is union-recovered. + assert sorted({s for tup in seen_args for s in tup}) == sorted( + ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + ) + # Frames concat to one row per sub-request id, in deterministic order. + assert len(df) == len(seen_args) + + +def test_async_fan_out_failure_yields_resumable_call(monkeypatch): + """A transient 5xx mid-fan-out raises ``ServiceInterrupted`` whose + ``.call`` is a ``ChunkedCall`` holding the completed sub-requests + in a sparse index map. ``exc.call.resume()`` re-issues only the + unfinished sub-requests, via the sync ``fetch_once`` path.""" + call_count = {"async": 0, "sync": 0} + + async def fetch_async(args): + call_count["async"] += 1 + # First sub-request succeeds; siblings fail. + if call_count["async"] == 1: + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + raise ServiceUnavailable("503: simulated") + + monkeypatch.setenv("API_USGS_CONCURRENT", "16") + + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): + call_count["sync"] += 1 + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + + with pytest.raises(ServiceInterrupted) as exc_info: + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + interrupted = exc_info.value + assert interrupted.call is not None, "parallel-mode interruption must be resumable" + # First sub-request completed; the rest still owe. + assert interrupted.completed_chunks == 1 + assert interrupted.total_chunks > 1 + + # Resume on the sync path picks up only the missing sub-requests. + sync_before = call_count["sync"] + df, _ = interrupted.call.resume() + sync_calls_on_resume = call_count["sync"] - sync_before + assert sync_calls_on_resume == interrupted.total_chunks - 1 + # Final frame unions every sub-args. + assert len(df) == interrupted.total_chunks + + +@pytest.mark.parametrize( + "fallback_trigger,warning_match", + [ + pytest.param("running_loop", "running asyncio event loop", id="running-loop"), + pytest.param("no_fetch_async", "no async fetch sibling", id="missing-async"), + ], +) +def test_async_falls_back_to_serial_with_warning( + monkeypatch, fallback_trigger, warning_match +): + """The parallel path falls back to the serial ``ChunkedCall`` + (with a ``UserWarning``) in two situations: + + * a running asyncio event loop (Jupyter / IPython kernels, async + apps) — ``asyncio.run`` would otherwise raise ``RuntimeError``; + * the decorator wasn't wired with a ``fetch_async=`` sibling — + ``API_USGS_CONCURRENT`` would otherwise be a silent no-op. + """ + sync_calls = [] + monkeypatch.setenv("API_USGS_CONCURRENT", "16") + + if fallback_trigger == "running_loop": + + async def fetch_async(args): + raise AssertionError("parallel path must not run inside an active loop") + + @multi_value_chunked( + build_request=_fake_build, fetch_async=fetch_async, url_limit=240 + ) + def fetch(args): + sync_calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + async def driver(): + with pytest.warns(UserWarning, match=warning_match): + return fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + df, _ = asyncio.run(driver()) + else: + + @multi_value_chunked(build_request=_fake_build, url_limit=240) + def fetch(args): + sync_calls.append(tuple(args["sites"])) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + with pytest.warns(UserWarning, match=warning_match): + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + + assert len(sync_calls) > 1 + assert len(df) == len(sync_calls) + + +def test_async_fan_out_cancellation_wins_over_transient_sibling(monkeypatch): + """``asyncio.CancelledError`` raised by any sub-request must + propagate unmodified, even when a sibling raises a recognized + transient (which would otherwise wrap as a resumable + :class:`ChunkInterrupted`). Cancellation is asyncio's abort + signal — letting a transient-classification path consume it + would silently swallow the user's stop request. + + fetch_async has no ``await`` inside its body, so gather schedules + the tasks in submission order and each runs synchronously to its + raise — making ``call_count`` deterministic for this test: + 1 = probe, 2 = first fan-out task (transient), 3 = second + fan-out task (cancellation). + """ + call_count = {"async": 0} + + async def fetch_async(args): + call_count["async"] += 1 + if call_count["async"] == 1: + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + if call_count["async"] == 2: + raise ServiceUnavailable("503: transient sibling") + if call_count["async"] == 3: + raise asyncio.CancelledError("user cancel") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=99) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + + # 8 × 20-byte sites force the planner to >=3 sub-args under + # url_limit=240, so the fan-out gather sees at least the + # transient (call 2) AND the cancellation (call 3). + sites = [f"S{i}" * 10 for i in range(1, 9)] + + with pytest.raises(asyncio.CancelledError): + fetch({"sites": sites}) + + def test_combine_chunk_responses_does_not_mutate_input_urls(): """Regression for the _set_response_url aliasing bug. @@ -1230,3 +1425,228 @@ def test_combine_chunk_responses_does_not_mutate_input_urls(): assert str(r2.url) == "https://example.com/chunk1" assert str(req1.url) == "https://example.com/chunk0" assert str(req2.url) == "https://example.com/chunk1" + + +# --------------------------------------------------------------------------- +# Retry-with-backoff: RetryPolicy + _retryable + drivers + decorator wiring. +# Conftest pins API_USGS_RETRIES=0, so these tests opt in explicitly and +# patch chunking._SLEEP / chunking._ASLEEP to no-ops (no real backoff). +# --------------------------------------------------------------------------- + + +def _wrap_cause(transport_exc): + """Wrap ``transport_exc`` the way ``_walk_pages`` does — a + ``RuntimeError`` with the typed transport error on ``__cause__`` — so + chain-walking is exercised realistically.""" + try: + raise RuntimeError("Paginated request failed") from transport_exc + except RuntimeError as wrapped: + return wrapped + + +# -- RetryPolicy (pure value object) ---------------------------------------- + + +def test_retry_policy_backoff_honors_retry_after(): + policy = RetryPolicy() + # A server Retry-After overrides the computed backoff verbatim. + assert policy.backoff(attempt=1, retry_after=7.5) == 7.5 + assert policy.backoff(attempt=4, retry_after=2.0) == 2.0 + + +def test_retry_policy_backoff_full_jitter_within_ceiling(): + policy = RetryPolicy(base_backoff=2.0, max_backoff=30.0) + for attempt, ceiling in [(1, 2.0), (2, 4.0), (3, 8.0), (5, 30.0)]: + samples = [policy.backoff(attempt, None) for _ in range(200)] + assert all(0.0 <= s <= ceiling for s in samples) + # Full jitter genuinely varies and reaches below the ceiling. + assert min(samples) < ceiling + + +def test_retry_policy_should_retry_exhaustion(): + policy = RetryPolicy(max_retries=2) + assert policy.should_retry(attempt=1, retry_after=None) + assert policy.should_retry(attempt=2, retry_after=None) + assert not policy.should_retry(attempt=3, retry_after=None) + + +def test_retry_policy_long_retry_after_escalates(): + policy = RetryPolicy(max_retries=5, retry_after_cap=60.0) + assert policy.should_retry(attempt=1, retry_after=30.0) # absorbed inline + assert not policy.should_retry(attempt=1, retry_after=120.0) # escalates + + +def test_retry_policy_from_env(monkeypatch): + monkeypatch.setenv("API_USGS_RETRIES", "2") + assert RetryPolicy.from_env().max_retries == 2 + monkeypatch.setenv("API_USGS_RETRIES", "0") + assert RetryPolicy.from_env().max_retries == 0 + monkeypatch.delenv("API_USGS_RETRIES", raising=False) + assert RetryPolicy.from_env().max_retries == _chunking._RETRIES_DEFAULT + monkeypatch.setenv("API_USGS_RETRIES", "-1") + with pytest.raises(ValueError): + RetryPolicy.from_env() + monkeypatch.setenv("API_USGS_RETRIES", "lots") + with pytest.raises(ValueError): + RetryPolicy.from_env() + + +# -- _retryable taxonomy ---------------------------------------------------- + + +def test_retryable_taxonomy(): + assert _retryable(RateLimited("429", retry_after=5.0)) == (True, 5.0) + assert _retryable(ServiceUnavailable("503")) == (True, None) + assert _retryable(httpx.ReadTimeout("slow")) == (True, None) + assert _retryable(httpx.ConnectError("down")) == (True, None) + # InvalidURL is resumable but NOT retryable (a too-long cursor won't fix). + assert _retryable(httpx.InvalidURL("too long")) == (False, None) + # Plain non-transient (e.g. a 4xx programmer error wrapped as RuntimeError). + assert _retryable(RuntimeError("400")) == (False, None) + + +def test_retryable_walks_cause_chain(): + assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (True, 3.0) + + +# -- sync driver ------------------------------------------------------------ + + +def test_retry_sync_transient_then_success(monkeypatch): + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + if calls["n"] <= 2: + raise RateLimited("429") + return "ok" + + assert _retry_sync(fn, RetryPolicy(max_retries=3, base_backoff=0.0)) == "ok" + assert calls["n"] == 3 # two failures + one success + + +def test_retry_sync_exhausted_reraises(monkeypatch): + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + raise ServiceUnavailable("503") + + with pytest.raises(ServiceUnavailable): + _retry_sync(fn, RetryPolicy(max_retries=2, base_backoff=0.0)) + assert calls["n"] == 3 # first attempt + 2 retries + + +def test_retry_sync_non_retryable_not_retried(monkeypatch): + slept: list[float] = [] + monkeypatch.setattr(_chunking, "_SLEEP", slept.append) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + raise RuntimeError("400: bad request") + + with pytest.raises(RuntimeError): + _retry_sync(fn, RetryPolicy(max_retries=3)) + assert calls["n"] == 1 and slept == [] + + +def test_retry_sync_long_retry_after_escalates(monkeypatch): + slept: list[float] = [] + monkeypatch.setattr(_chunking, "_SLEEP", slept.append) + calls = {"n": 0} + + def fn(): + calls["n"] += 1 + raise RateLimited("429", retry_after=999.0) + + with pytest.raises(RateLimited): + _retry_sync(fn, RetryPolicy(max_retries=5, retry_after_cap=60.0)) + assert calls["n"] == 1 and slept == [] # no inline wait for a long window + + +# -- async driver ----------------------------------------------------------- + + +def test_retry_async_transient_then_success(monkeypatch): + async def _noslept(_d): + return None + + monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + calls = {"n": 0} + + async def afn(): + calls["n"] += 1 + if calls["n"] == 1: + raise httpx.ReadTimeout("slow") + return "ok" + + out = asyncio.run(_retry_async(afn, RetryPolicy(max_retries=3, base_backoff=0.0))) + assert out == "ok" and calls["n"] == 2 + + +# -- end-to-end through the decorator -------------------------------------- + + +def test_chunker_retries_transient_then_completes(monkeypatch): + """A transient on one sub-request is retried transparently; the + decorated call completes with no ChunkInterrupted.""" + monkeypatch.setenv("API_USGS_RETRIES", "3") + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + state = {"failed": False} + + def fetch(args): + # Fail the first sub-request once, then succeed everywhere. + if not state["failed"]: + state["failed"] = True + raise RateLimited("429: Too many requests made.") + return pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + sites = ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10] + df, _ = decorated({"sites": sites}) + assert sorted(df["sites"]) == sorted(sites) # all recovered despite the 429 + + +def test_chunker_exhausted_retries_still_resumable(monkeypatch): + """When retries are exhausted the failure still surfaces as a + resumable ChunkInterrupted — retries don't swallow the escape hatch.""" + monkeypatch.setenv("API_USGS_RETRIES", "2") + monkeypatch.setattr(_chunking, "_SLEEP", lambda _d: None) + attempts = {"n": 0} + + def fetch(args): + sites = list(args["sites"]) + if "S1" * 10 in sites: + attempts["n"] += 1 + raise ServiceUnavailable("503: service unavailable") + return pd.DataFrame({"sites": sites}), _quota_response(500) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(ServiceInterrupted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert excinfo.value.call is not None + assert attempts["n"] == 3 # first attempt + 2 retries before giving up + + +def test_async_fan_out_retries_transient_then_completes(monkeypatch): + """The parallel path retries a transient sub-request and completes.""" + monkeypatch.setenv("API_USGS_RETRIES", "3") + + async def _noslept(_d): + return None + + monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + state = {"failed": False} + + async def fetch_async(args): + if not state["failed"]: + state["failed"] = True + raise ServiceUnavailable("503: transient") + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response() + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert len(df) > 1 # every sub-args atom recovered after the retry diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index faa61630..30be56a2 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -65,6 +65,26 @@ def test_page_count_is_pluralized(): assert "2 pages" in stream.getvalue() +def test_note_retry_renders_then_clears_on_next_page(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.set_chunks(3) + reporter.start_chunk(1) + reporter.note_retry(attempt=2, wait=8.0) + assert "retrying (attempt 2, waiting 8s)" in stream.getvalue() + # The next page redraws without the note (last frame is after the + # final carriage return). + reporter.add_page(rows=5) + assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] + + +def test_note_retry_is_noop_when_disabled(): + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=False) + reporter.note_retry(attempt=1, wait=1.0) + assert stream.getvalue() == "" + + def test_chunk_segment_only_shown_when_multiple_chunks(): single = io.StringIO() reporter = ProgressReporter(stream=single, enabled=True) @@ -363,3 +383,124 @@ def test_broken_progress_stream_does_not_truncate_pagination(): df, _ = _walk_pages(geopd=False, req=req, client=client) assert len(df) == 2 # both pages returned despite the broken progress stream + + +# -- async path integration ---------------------------------------------------- + + +def test_paginate_async_reports_pages_through_active_reporter(monkeypatch): + """The async paginate path must drive the same progress reporter the + sync path does. Pages and rate-limit updates from each completed + page should land via the active ``ProgressReporter``, exactly as + they would on ``_walk_pages``.""" + import asyncio + + from dataretrieval.waterdata.utils import _paginate_async + + resp1 = _resp( + [{"id": "1", "properties": {"v": "a"}}], + next_url="https://example.com/p2", + rate_remaining="4999", + ) + resp2 = _resp([{"id": "2", "properties": {"v": "b"}}], rate_remaining="4998") + + async def parse_response(resp): + body = resp.json() + nxt = next( + (link["href"] for link in body["links"] if link["rel"] == "next"), None + ) + return mock.MagicMock(empty=False, __len__=lambda self: 1), nxt + + # _paginate_async expects parse_response to be sync, like the sync path. + def parse_sync(resp): + body = resp.json() + nxt = next( + (link["href"] for link in body["links"] if link["rel"] == "next"), None + ) + import pandas as pd + + return pd.DataFrame(body["features"]), nxt + + async def follow_up(cursor, sess): + return resp2 + + client = mock.AsyncMock(spec=httpx.AsyncClient) + client.send.return_value = resp1 + + req = mock.MagicMock(spec=httpx.Request) + req.method = "GET" + req.headers = {} + req.url = "https://example.com/p1" + + stream = io.StringIO() + + async def run(): + with progress_context(service="continuous", stream=stream, enabled=True): + df, _ = await _paginate_async( + req, + parse_response=parse_sync, + follow_up=follow_up, + client=client, + ) + return df + + df = asyncio.run(run()) + assert len(df) == 2 + out = stream.getvalue() + assert "Retrieving: continuous ·" in out + assert "2 pages" in out + assert "4,998 requests remaining" in out + assert out.endswith("\n") + + +def test_fan_out_async_sets_chunks_on_active_reporter(monkeypatch): + """``_fan_out_async`` records ``plan.total`` on the active reporter + so the progress line knows how many sub-requests are in flight. + It deliberately does NOT call ``start_chunk`` (which would be + misleading under parallel fan-out — chunks fire concurrently).""" + import asyncio + + import pandas as pd + + from dataretrieval.waterdata.chunking import ChunkPlan, _fan_out_async + + # Fake build_request whose URL length scales with the sites list, + # mirroring the planner's _request_bytes contract. _FakeReq has the + # same shape as httpx.Request for sizing purposes. + class _FakeReq: + __slots__ = ("url", "content") + + def __init__(self, url): + self.url = url + self.content = b"" + + def build(*, sites): + return _FakeReq("x" * (200 + len(",".join(sites)))) + + sites = ["S" * 10 + str(i) for i in range(4)] + plan = ChunkPlan({"sites": sites}, build, url_limit=240) + assert plan.total > 1, "test setup error: plan must fan out" + + async def fetch_async(args): + return pd.DataFrame({"id": [",".join(args["sites"])]}), mock.Mock( + elapsed=__import__("datetime").timedelta(seconds=0.01), + headers={"x-ratelimit-remaining": "999"}, + ) + + def fetch_once(args): # noqa: ARG001 — never invoked on the happy parallel path + raise AssertionError("sync fetch must not run in this test") + + stream = io.StringIO() + + async def run(): + with progress_context(service="daily", stream=stream, enabled=True) as rep: + await _fan_out_async(plan, fetch_once, fetch_async, max_concurrent=4) + return rep.total_chunks, rep.current_chunk + + total_recorded, current_recorded = asyncio.run(run()) + assert total_recorded == plan.total + # Each sub-request that completes bumps current_chunk via + # start_chunk(len(completed)), so by the time the gather finishes + # current_chunk reflects the total number of successful chunks — + # plan.total in the all-success case. + assert current_recorded == plan.total diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index bb5ece10..413f39c8 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -221,6 +221,37 @@ def test_walk_pages_wraps_initial_page_parse_error(): assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) +def test_walk_pages_async_wraps_initial_page_parse_error(): + """Async sibling of the above. ``_paginate_async`` must wrap an + initial-page parse failure with the same ``RuntimeError`` shape so + callers get a consistent diagnostic across sync and async paths.""" + import asyncio + + from dataretrieval.waterdata.utils import _walk_pages_async + + resp = mock.MagicMock() + resp.status_code = 200 + resp.url = "https://example.com/page1" + resp.json.side_effect = json.JSONDecodeError("Expecting value", "...", 0) + + mock_client = mock.AsyncMock(spec=httpx.AsyncClient) + mock_client.send.return_value = resp + + mock_req = mock.MagicMock(spec=httpx.Request) + mock_req.method = "GET" + mock_req.headers = {} + mock_req.content = b"" + mock_req.url = "https://example.com/page1" + + async def run(): + await _walk_pages_async(geopd=False, req=mock_req, client=mock_client) + + with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: + asyncio.run(run()) + + assert isinstance(excinfo.value.__cause__, json.JSONDecodeError) + + def test_get_resp_data_handles_missing_features_key(): """Regression: a 200 with ``numberReturned > 0`` but no ``features`` key (real schema-drift shape) used to crash From 1d0acb7912cf7ffdeb2e5be884d1c619b0b2e9fe Mon Sep 17 00:00:00 2001 From: thodson-usgs Date: Tue, 26 May 2026 21:14:41 -0500 Subject: [PATCH 2/2] refactor(waterdata): maintainer polish for chunk retries --- dataretrieval/waterdata/_progress.py | 20 ++- dataretrieval/waterdata/chunking.py | 218 ++++++++++++++++----------- tests/waterdata_chunking_test.py | 68 ++++++++- tests/waterdata_progress_test.py | 20 +++ 4 files changed, 235 insertions(+), 91 deletions(-) diff --git a/dataretrieval/waterdata/_progress.py b/dataretrieval/waterdata/_progress.py index 7104f3af..e529d6d3 100644 --- a/dataretrieval/waterdata/_progress.py +++ b/dataretrieval/waterdata/_progress.py @@ -157,10 +157,17 @@ def add_page(self, rows: int = 0) -> None: def note_retry(self, *, attempt: int, wait: float) -> None: """Show that a sub-request is backing off before retry ``attempt``. - Cleared by the next :meth:`add_page` / :meth:`start_chunk` so the - line returns to normal progress once the retry succeeds. + Cleared by the next :meth:`add_page` / :meth:`start_chunk` (or by + :meth:`close`) so the line returns to normal once the retry resolves. """ - self.retry_note = f"retrying (attempt {attempt}, waiting {wait:.0f}s)" + # Keep sub-second waits explicit (avoid misleading ``0s``) while + # rendering whole-second waits without unnecessary ``.0`` noise. + wait_1dp = round(wait, 1) + if wait_1dp < 1 or not wait_1dp.is_integer(): + secs = f"{wait_1dp:.1f}s" + else: + secs = f"{wait_1dp:.0f}s" + self.retry_note = f"retrying (attempt {attempt}, waiting {secs})" self._render() def set_rate_remaining( @@ -225,6 +232,13 @@ def close(self) -> None: """ if self._closed: return + # A retry note set during the final backoff would otherwise freeze as + # the persisted last line of a call that has since completed or given + # up; clear it and redraw (while still un-closed, so ``_render`` runs) + # so the final state isn't a stale "retrying". + if self.enabled and self._rendered and self.retry_note is not None: + self.retry_note = None + self._render() self._closed = True if not (self.enabled and self._rendered): return diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index 1e3b429d..2d01614c 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -62,7 +62,7 @@ from contextvars import ContextVar from dataclasses import dataclass from datetime import timedelta -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeVar from urllib.parse import quote_plus import httpx @@ -166,13 +166,14 @@ def _read_concurrency_env() -> int | None: return value -# Retry-with-backoff for transient sub-request failures (429 / 5xx / -# connect-read timeouts). The env var is read at call time so test -# ``monkeypatch.setenv`` takes effect; the timing constants are -# module-level so power users (and tests) can ``monkeypatch.setattr`` -# them. Defaults: 4 retries, 0.5s base doubling under full jitter up to -# a 30s per-attempt ceiling, and honor a server ``Retry-After`` up to -# 60s before escalating to a resumable interruption instead. +# Retry-with-backoff defaults for transient sub-request failures (429 / +# 5xx / connect-read timeouts). All four are resolved at call time by +# ``RetryPolicy.from_env`` (the env var via ``monkeypatch.setenv``, the +# timing constants via ``monkeypatch.setattr`` on this module), so both +# are overridable in tests and by power users. Defaults: 4 retries, 0.5s +# base doubling under full jitter up to a 30s per-attempt ceiling, and +# honor a server ``Retry-After`` up to 60s before escalating to a +# resumable interruption instead. _RETRIES_ENV = "API_USGS_RETRIES" _RETRIES_DEFAULT = 4 _RETRY_BASE_BACKOFF = 0.5 @@ -237,10 +238,31 @@ class RetryPolicy: max_backoff: float = _RETRY_MAX_BACKOFF retry_after_cap: float = _RETRY_AFTER_CAP + def __post_init__(self) -> None: + # Guard the value object's own invariants so a misconfiguration + # fails loudly at construction rather than as a downstream + # ``time.sleep`` ValueError (negative delay) or a silent + # asyncio.sleep-treats-negative-as-zero divergence. + if self.max_retries < 0: + raise ValueError(f"max_retries must be >= 0 (got {self.max_retries}).") + if self.base_backoff < 0 or self.max_backoff < 0 or self.retry_after_cap < 0: + raise ValueError("retry backoff settings must be non-negative.") + @classmethod def from_env(cls) -> RetryPolicy: - """Build a policy, resolving ``max_retries`` from ``API_USGS_RETRIES``.""" - return cls(max_retries=_read_retries_env()) + """Build a policy from the module-level defaults, resolved now. + + ``max_retries`` comes from ``API_USGS_RETRIES``; the timing knobs + are read from the ``_RETRY_*`` module constants at call time (not + the dataclass field defaults, which freeze at class definition) so + ``monkeypatch.setattr`` on those constants takes effect. + """ + return cls( + max_retries=_read_retries_env(), + base_backoff=_RETRY_BASE_BACKOFF, + max_backoff=_RETRY_MAX_BACKOFF, + retry_after_cap=_RETRY_AFTER_CAP, + ) def should_retry(self, attempt: int, retry_after: float | None) -> bool: """Whether a just-failed ``attempt`` (1-based) warrants another try. @@ -276,42 +298,36 @@ def backoff(self, attempt: int, retry_after: float | None) -> float: "_chunked_client", default=None ) -# Async sibling of ``_chunked_client``. Published by -# ``_publish_async_client`` during ``_fan_out_async`` so async -# paginated-loop helpers reuse one ``httpx.AsyncClient`` (and its -# connection pool) across every concurrent sub-request of a single -# chunked call. +# Async sibling of ``_chunked_client``. Published (via :func:`_publish`) +# during ``_fan_out_async`` so async paginated-loop helpers reuse one +# ``httpx.AsyncClient`` (and its connection pool) across every concurrent +# sub-request of a single chunked call. _chunked_async_client: ContextVar[httpx.AsyncClient | None] = ContextVar( "_chunked_async_client", default=None ) - -@contextmanager -def _publish_client(client: httpx.Client) -> Iterator[None]: - """ - Make ``client`` visible to :func:`get_active_client` for the - duration of the ``with`` block via the ``_chunked_client`` - ContextVar. Wraps the set/reset token dance so callers don't have to. - """ - token = _chunked_client.set(client) - try: - yield - finally: - _chunked_client.reset(token) +_ClientT = TypeVar("_ClientT") @contextmanager -def _publish_async_client(client: httpx.AsyncClient) -> Iterator[None]: +def _publish(var: ContextVar[_ClientT | None], client: _ClientT) -> Iterator[None]: """ - Make ``client`` visible to :func:`get_active_async_client` for the - duration of the ``with`` block. Async sibling of - :func:`_publish_client`. + Bind ``client`` to the ContextVar ``var`` for the duration of the + ``with`` block (wrapping the set/reset token dance), so paginated-loop + helpers can borrow the chunker's shared client via + :func:`get_active_client` / :func:`get_active_async_client`. + + Generic over the client type so the sync (:class:`httpx.Client` via + ``_chunked_client``) and async (:class:`httpx.AsyncClient` via + ``_chunked_async_client``) paths share one implementation, while the + ``_ClientT`` type var still lets a type checker reject a var/client + type mismatch. """ - token = _chunked_async_client.set(client) + token = var.set(client) try: yield finally: - _chunked_async_client.reset(token) + var.reset(token) def get_active_client() -> httpx.Client | None: @@ -325,8 +341,8 @@ def get_active_client() -> httpx.Client | None: Returns ------- httpx.Client or None - The client published by :func:`_publish_client` if currently - inside a :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. + The client published via :func:`_publish` if currently inside a + :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. """ return _chunked_client.get() @@ -1069,13 +1085,18 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: """ Decide whether ``exc`` is a transient worth an automatic retry. - Narrower than :func:`_classify_chunk_error`: it retries rate limits - (429), service errors (5xx), and genuine transport transients - (:class:`httpx.TransportError` — ``ConnectError``, ``ReadTimeout``, …) - but NOT :class:`httpx.InvalidURL` (a too-long server cursor URL won't - fix on retry, though it stays *resumable*). Walks the ``__cause__`` - chain because ``_walk_pages`` re-wraps mid-pagination failures as - ``RuntimeError``. + Inspects only the *top-level* exception, by design — and so is + deliberately narrower than :func:`_classify_chunk_error`, which walks + the ``__cause__`` chain for resumability. ``_paginate`` raises an + initial-request transient (429 / 5xx / :class:`httpx.TransportError` + such as ``ConnectError`` / ``ReadTimeout``) *raw*, but re-wraps any + mid-pagination failure as a ``RuntimeError``. Retrying only the raw, + top-level transient means we re-issue a sub-request that made no + progress (cheap), while a failure after partial pagination escalates + to the resumable :class:`ChunkInterrupted` instead of being re-walked + from page 1 — which would re-spend the very quota that was exhausted. + ``httpx.InvalidURL`` is excluded (a too-long cursor won't fix on + retry), and it only ever arises on a follow-up page anyway. Returns ------- @@ -1083,13 +1104,10 @@ def _retryable(exc: BaseException) -> tuple[bool, float | None]: ``(retryable, retry_after)`` — the server ``Retry-After`` hint (seconds) when the transient carried one, else ``None``. """ - cur: BaseException | None = exc - while cur is not None: - if isinstance(cur, (RateLimited, ServiceUnavailable)): - return True, cur.retry_after - if isinstance(cur, httpx.TransportError): - return True, None - cur = cur.__cause__ + if isinstance(exc, (RateLimited, ServiceUnavailable)): + return True, exc.retry_after + if isinstance(exc, httpx.TransportError): + return True, None return False, None @@ -1334,6 +1352,10 @@ def __init__( # subsequent ``resume()`` only re-issues the missing indices. # On the serial path this fills contiguously from 0. self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} + # Explicit completion order for response-header aggregation. + # Keeping this separate from ``_chunks`` avoids coupling that + # behavior to dict insertion semantics or future write patterns. + self._completion_order: list[int] = [] def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: """Record a completed sub-request's ``(frame, response)`` pair @@ -1341,6 +1363,8 @@ def record(self, index: int, pair: tuple[pd.DataFrame, httpx.Response]) -> None: :meth:`resume` and the parallel fan-out in :func:`_fan_out_async` so the completion set stays encapsulated.""" + if index not in self._chunks: + self._completion_order.append(index) self._chunks[index] = pair def wrap_failure(self, exc: BaseException) -> ChunkInterrupted | None: @@ -1369,6 +1393,27 @@ def completed_chunks(self) -> int: def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]: return [self._chunks[i] for i in sorted(self._chunks)] + def _responses_by_completion(self) -> list[httpx.Response]: + # The final element is the most-recently completed sub-request, whose + # headers carry the freshest ``x-ratelimit-remaining`` for aggregation. + return [self._chunks[i][1] for i in self._completion_order] + + def combined(self) -> tuple[pd.DataFrame, httpx.Response]: + """Combine every recorded sub-request into one ``(frame, response)``. + + Frames concatenate in sub-args *index* order (deterministic, + independent of parallel completion order); the aggregated response + takes its headers from the most-recently-*completed* sub-request, so + a fan-out that finished chunks out of index order still surfaces the + latest rate-limit state the server reported rather than a stale one. + """ + return ( + _combine_chunk_frames([frame for frame, _ in self._ordered_chunks()]), + _combine_chunk_responses( + self._responses_by_completion(), self.plan.canonical_url + ), + ) + @property def partial_frame(self) -> pd.DataFrame: """ @@ -1405,7 +1450,7 @@ def partial_response(self) -> httpx.Response | None: if not self._chunks: return None return _combine_chunk_responses( - [resp for _, resp in self._ordered_chunks()], self.plan.canonical_url + self._responses_by_completion(), self.plan.canonical_url ) def resume(self) -> tuple[pd.DataFrame, httpx.Response]: @@ -1443,23 +1488,18 @@ def resume(self) -> tuple[pd.DataFrame, httpx.Response]: is on ``exc.call`` — wait for the underlying condition to clear and call ``exc.call.resume()`` again. """ - with httpx.Client(**HTTPX_DEFAULTS) as client, _publish_client(client): - reporter = _progress.current() - if reporter is not None: - reporter.set_chunks(self.plan.total) - for i, sub_args in enumerate(self.plan.iter_sub_args()): - if i in self._chunks: - continue + with httpx.Client(**HTTPX_DEFAULTS) as client: + with _publish(_chunked_client, client): + reporter = _progress.current() if reporter is not None: - reporter.start_chunk(i + 1) - self._issue(i, sub_args) - ordered = self._ordered_chunks() - frames = [frame for frame, _ in ordered] - responses = [resp for _, resp in ordered] - return ( - _combine_chunk_frames(frames), - _combine_chunk_responses(responses, self.plan.canonical_url), - ) + reporter.set_chunks(self.plan.total) + for i, sub_args in enumerate(self.plan.iter_sub_args()): + if i in self._chunks: + continue + if reporter is not None: + reporter.start_chunk(i + 1) + self._issue(i, sub_args) + return self.combined() def _issue(self, index: int, sub_args: dict[str, Any]) -> None: """ @@ -1556,13 +1596,17 @@ async def _fan_out_async( limits = httpx.Limits( max_connections=max_concurrent, max_keepalive_connections=max_concurrent ) - # ``sys.maxsize`` stands in for "unbounded": ``asyncio.Semaphore`` - # only decrements a counter, never preallocates slots. - semaphore = asyncio.Semaphore(max_concurrent or sys.maxsize) + # ``None`` means "unbounded"; ``sys.maxsize`` stands in for it since + # ``asyncio.Semaphore`` only decrements a counter, never preallocates + # slots. Test ``is None`` explicitly so a stray ``0`` isn't silently + # promoted to unbounded by a falsy-``or``. + semaphore = asyncio.Semaphore( + sys.maxsize if max_concurrent is None else max_concurrent + ) call = ChunkedCall(plan, fetch_once, retry_policy) async with httpx.AsyncClient(limits=limits, **HTTPX_DEFAULTS) as client: - with _publish_async_client(client): + with _publish(_chunked_async_client, client): reporter = _progress.current() if reporter is not None: reporter.set_chunks(plan.total) @@ -1586,15 +1630,16 @@ async def track( # Dispatch every sub-request concurrently. ``return_exceptions`` # keeps completed pairs after a sibling fails, so partial state # stays recoverable via ``ChunkedCall.resume()``. Failure - # precedence: + # precedence, in order: # 1. Cancellation / interrupt signals (CancelledError, # KeyboardInterrupt, SystemExit — non-Exception) propagate # unmodified; wrapping them as a transient would swallow the # user's stop signal. - # 2. Recognized transients wrap as ChunkInterrupted so the user - # gets a resumable handle even when a non-transient failure - # landed earlier in submission order. - # 3. Otherwise re-raise the first failure, preserving its type. + # 2. A non-transient failure (a real bug — unrecognized by + # ``wrap_failure``) surfaces raw, so it isn't masked behind a + # resumable handle for a transient sibling that landed later. + # 3. Only when every failure is a recognized transient do we + # raise the first as a resumable ``ChunkInterrupted``. results = await asyncio.gather( *(track(i, args) for i, args in enumerate(sub_args_list)), return_exceptions=True, @@ -1603,17 +1648,18 @@ async def track( for exc in failures: if not isinstance(exc, Exception): raise exc + first_transient: tuple[ChunkInterrupted, BaseException] | None = None for exc in failures: - if (interrupted := call.wrap_failure(exc)) is not None: - raise interrupted from exc - if failures: - raise failures[0] - - ordered = call._ordered_chunks() - return ( - _combine_chunk_frames([df for df, _ in ordered]), - _combine_chunk_responses([resp for _, resp in ordered], plan.canonical_url), - ) + interrupted = call.wrap_failure(exc) + if interrupted is None: + raise exc + if first_transient is None: + first_transient = (interrupted, exc) + if first_transient is not None: + interrupted, exc = first_transient + raise interrupted from exc + + return call.combined() def multi_value_chunked( diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index ee129aaa..e9500bb4 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -1267,6 +1267,25 @@ async def fetch_async(args): assert len(df) == len(seen_args) +def test_async_fan_out_aggregates_headers_from_latest_completion(monkeypatch): + """Aggregated headers reflect the most recently completed chunk. + + Completion order can differ from index order in parallel mode, so + rate-limit headers should come from whichever sub-request finished + last, not from the highest sub-args index. + """ + + async def fetch_async(args): + if "S1" * 10 in args["sites"]: + await asyncio.sleep(0.02) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=11) + return pd.DataFrame({"id": [_atom_id(args)]}), _ok_response(remaining=77) + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + _, response = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) + assert response.headers[_QUOTA_HEADER] == "11" + + def test_async_fan_out_failure_yields_resumable_call(monkeypatch): """A transient 5xx mid-fan-out raises ``ServiceInterrupted`` whose ``.call`` is a ``ChunkedCall`` holding the completed sub-requests @@ -1491,6 +1510,24 @@ def test_retry_policy_from_env(monkeypatch): RetryPolicy.from_env() +def test_retry_policy_rejects_invalid_settings(): + with pytest.raises(ValueError): + RetryPolicy(max_retries=-1) + with pytest.raises(ValueError): + RetryPolicy(base_backoff=-0.5) + with pytest.raises(ValueError): + RetryPolicy(max_backoff=-1.0) + + +def test_retry_policy_from_env_honors_monkeypatched_constants(monkeypatch): + # The timing knobs are read from the module constants at call time, so + # monkeypatching them (as the module comment promises) takes effect. + monkeypatch.setattr(_chunking, "_RETRY_MAX_BACKOFF", 0.0) + monkeypatch.setattr(_chunking, "_RETRY_BASE_BACKOFF", 0.0) + policy = RetryPolicy.from_env() + assert policy.max_backoff == 0.0 and policy.base_backoff == 0.0 + + # -- _retryable taxonomy ---------------------------------------------------- @@ -1505,8 +1542,13 @@ def test_retryable_taxonomy(): assert _retryable(RuntimeError("400")) == (False, None) -def test_retryable_walks_cause_chain(): - assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (True, 3.0) +def test_retryable_skips_wrapped_midpagination_transient(): + # A transient surfaced mid-pagination is re-wrapped as RuntimeError by + # _paginate; it must NOT be auto-retried (re-walking from page 1 would + # re-spend quota) — it escalates to the resumable handle instead. Only + # the raw, top-level (initial-request) transient is retryable. + assert _retryable(_wrap_cause(RateLimited("429", retry_after=3.0))) == (False, None) + assert _retryable(RateLimited("429", retry_after=3.0)) == (True, 3.0) # -- sync driver ------------------------------------------------------------ @@ -1650,3 +1692,25 @@ async def fetch_async(args): fetch = _async_chunked_fetch(monkeypatch, fetch_async) df, _ = fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) assert len(df) > 1 # every sub-args atom recovered after the retry + + +def test_async_fan_out_surfaces_fatal_over_transient(monkeypatch): + """A non-transient bug in one sub-request surfaces raw rather than + being masked behind a resumable interruption from a transient sibling.""" + monkeypatch.setenv("API_USGS_RETRIES", "2") + + async def _noslept(_d): + return None + + monkeypatch.setattr(_chunking, "_ASLEEP", _noslept) + + async def fetch_async(args): + # One chunk carries a deterministic programmer error; the rest are + # transient. The real bug must win over the resumable transient. + if "S1" * 10 in args["sites"]: + raise ValueError("deterministic bug") + raise ServiceUnavailable("503: transient") + + fetch = _async_chunked_fetch(monkeypatch, fetch_async) + with pytest.raises(ValueError, match="deterministic bug"): + fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index 30be56a2..a98dc76a 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -78,6 +78,26 @@ def test_note_retry_renders_then_clears_on_next_page(): assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] +def test_note_retry_subsecond_wait_shows_decimal(): + # A sub-second backoff must not collapse to a misleading "0s". + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.note_retry(attempt=1, wait=0.3) + out = stream.getvalue() + assert "waiting 0.3s" in out and "waiting 0s" not in out + + +def test_note_retry_cleared_on_close(): + # An exhausted retry leaves retry_note set with no following page; + # close() must clear it so the persisted last line isn't a stale note. + stream = io.StringIO() + reporter = ProgressReporter(stream=stream, enabled=True) + reporter.add_page(rows=1) + reporter.note_retry(attempt=3, wait=5.0) + reporter.close() + assert "retrying" not in stream.getvalue().rsplit("\r", 1)[-1] + + def test_note_retry_is_noop_when_disabled(): stream = io.StringIO() reporter = ProgressReporter(stream=stream, enabled=False)