diff --git a/dataretrieval/nadp.py b/dataretrieval/nadp.py index 370b4927..3d1ee442 100644 --- a/dataretrieval/nadp.py +++ b/dataretrieval/nadp.py @@ -34,7 +34,9 @@ import warnings import zipfile -import requests +import httpx + +from dataretrieval.utils import HTTPX_DEFAULTS _DEPRECATION_MESSAGE = ( "The `nadp` module is deprecated and will be removed from `dataretrieval` " @@ -213,7 +215,7 @@ def get_zip(url, filename): """ _warn_deprecated() - req = requests.get(url + filename) + req = httpx.get(url + filename, **HTTPX_DEFAULTS) req.raise_for_status() # z = zipfile.ZipFile(io.BytesIO(req.content)) diff --git a/dataretrieval/nldi.py b/dataretrieval/nldi.py index 8304c397..e54ceb85 100644 --- a/dataretrieval/nldi.py +++ b/dataretrieval/nldi.py @@ -20,7 +20,7 @@ def _query_nldi(url, query_params, error_message): # A helper function to query the NLDI API response = query(url, payload=query_params) if response.status_code != 200: - raise ValueError(f"{error_message}. Error reason: {response.reason}") + raise ValueError(f"{error_message}. Error reason: {response.reason_phrase}") response_data = {} try: @@ -453,6 +453,14 @@ def _validate_data_source(data_source: str): available_data_sources = _query_nldi( url, {}, "Error getting available data sources" ) + if not isinstance(available_data_sources, list) or not all( + isinstance(ds, dict) and "source" in ds for ds in available_data_sources + ): + raise ValueError( + "NLDI data-source catalog returned an unexpected shape; " + "expected a list of {'source': ..., ...} objects, got: " + f"{available_data_sources!r}" + ) _AVAILABLE_DATA_SOURCES = [ds["source"] for ds in available_data_sources] if data_source not in _AVAILABLE_DATA_SOURCES: diff --git a/dataretrieval/nwis.py b/dataretrieval/nwis.py index 25384d5e..cfcdc64e 100644 --- a/dataretrieval/nwis.py +++ b/dataretrieval/nwis.py @@ -11,8 +11,8 @@ import warnings from json import JSONDecodeError +import httpx import pandas as pd -import requests from dataretrieval.rdb import read_rdb from dataretrieval.utils import BaseMetadata @@ -110,7 +110,7 @@ def wrapper(*args, **kwargs): return wrapper -def _parse_json_or_raise(response: requests.Response) -> pd.DataFrame: +def _parse_json_or_raise(response: httpx.Response) -> pd.DataFrame: """Parse a JSON NWIS response, raising a helpful error on HTML responses.""" try: return _read_json(response.json()) @@ -364,9 +364,7 @@ def get_stats( @_deprecated -def query_waterdata( - service: str, ssl_check: bool = True, **kwargs -) -> requests.models.Response: +def query_waterdata(service: str, ssl_check: bool = True, **kwargs) -> httpx.Response: """ Queries waterdata. @@ -382,7 +380,7 @@ def query_waterdata( Returns ------- - request: ``requests.models.Response`` + request: ``httpx.Response`` The response object from the API request to the web service """ major_params = ["site_no", "state_cd"] @@ -412,7 +410,7 @@ def query_waterdata( @_deprecated def query_waterservices( service: str, ssl_check: bool = True, **kwargs -) -> requests.models.Response: +) -> httpx.Response: """ Queries waterservices.usgs.gov @@ -451,7 +449,7 @@ def query_waterservices( Returns ------- - request: ``requests.models.Response`` + request: ``httpx.Response`` The response object from the API request to the web service """ @@ -1123,7 +1121,7 @@ class NWIS_Metadata(BaseMetadata): Response url query_time: datetme.timedelta Response elapsed time - header: requests.structures.CaseInsensitiveDict + header: httpx.Headers Response headers comments: str | None Metadata comments, if any @@ -1143,7 +1141,7 @@ def __init__(self, response, **parameters) -> None: Parameters ---------- response: Response - Response object from requests module + Response object from httpx module parameters: unpacked dictionary Unpacked dictionary of the parameters supplied in the request diff --git a/dataretrieval/streamstats.py b/dataretrieval/streamstats.py index 7cddabaa..9a27f936 100644 --- a/dataretrieval/streamstats.py +++ b/dataretrieval/streamstats.py @@ -7,7 +7,9 @@ import json -import requests +import httpx + +from dataretrieval.utils import HTTPX_DEFAULTS def download_workspace(workspaceID, format=""): @@ -32,7 +34,7 @@ def download_workspace(workspaceID, format=""): payload = {"workspaceID": workspaceID, "format": format} url = "https://streamstats.usgs.gov/streamstatsservices/download" - r = requests.get(url, params=payload) + r = httpx.get(url, params=payload, **HTTPX_DEFAULTS) r.raise_for_status() return r @@ -125,7 +127,7 @@ def get_watershed( } url = "https://streamstats.usgs.gov/streamstatsservices/watershed.geojson" - r = requests.get(url, params=payload) + r = httpx.get(url, params=payload, **HTTPX_DEFAULTS) r.raise_for_status() diff --git a/dataretrieval/utils.py b/dataretrieval/utils.py index 76bbb6ad..7bb03a69 100644 --- a/dataretrieval/utils.py +++ b/dataretrieval/utils.py @@ -5,12 +5,17 @@ import warnings from collections.abc import Iterable +import httpx import pandas as pd -import requests import dataretrieval from dataretrieval.codes import tz +HTTPX_DEFAULTS = { + "follow_redirects": True, + "timeout": httpx.Timeout(60.0, connect=10.0), +} + def to_str(listlike, delimiter=","): """Translates list-like objects into strings. @@ -205,7 +210,7 @@ class BaseMetadata: Response url query_time: datetme.timedelta Response elapsed time - header: requests.structures.CaseInsensitiveDict + header: httpx.Headers Response headers """ @@ -216,7 +221,7 @@ def __init__(self, response) -> None: Parameters ---------- response: Response - Response object from requests module + Response object from httpx module Returns ------- @@ -225,8 +230,8 @@ def __init__(self, response) -> None: """ - # These are built from the API response - self.url = response.url + # Coerce httpx.URL -> str: BaseMetadata.url has always been str. + self.url = str(response.url) self.query_time = response.elapsed self.header = response.headers self.comment = None @@ -254,10 +259,29 @@ def __repr__(self) -> str: return f"{type(self).__name__}(url={self.url})" +_URL_TOO_LONG_EXAMPLE = """ + # n is the number of chunks to divide the query into \n + split_list = np.array_split(site_list, n) + data_list = [] # list to store chunk results in \n + # loop through chunks and make requests \n + for site_list in split_list: \n + data = nwis.get_record(sites=site_list, service='dv', \n + start=start, end=end) \n + data_list.append(data) # append results to list""" + + +def _url_too_long_error(detail: str) -> ValueError: + return ValueError( + "Request URL too long. Modify your query to use fewer sites. " + f"{detail}. Pseudo-code example of how to split your query: " + f"\n {_URL_TOO_LONG_EXAMPLE}" + ) + + def query(url, payload, delimiter=",", ssl_check=True): """Send a query. - Wrapper for requests.get that handles errors, converts listed + Wrapper for httpx.get that handles errors, converts listed query parameters to comma separated strings, and returns response. Parameters @@ -265,7 +289,7 @@ def query(url, payload, delimiter=",", ssl_check=True): url: string URL to query payload: dict - query parameters passed to ``requests.get`` + query parameters passed to ``httpx.get`` delimiter: string delimiter to use with lists ssl_check: bool @@ -275,19 +299,27 @@ def query(url, payload, delimiter=",", ssl_check=True): Returns ------- string: query response - The response from the API query ``requests.get`` function call. + The response from the API query ``httpx.get`` function call. """ for key, value in payload.items(): payload[key] = to_str(value, delimiter) - # for index in range(len(payload)): - # key, value = payload[index] - # payload[index] = (key, to_str(value)) + # httpx serializes None params as ``foo=``; USGS rejects with 400. + # Drop them. (``to_str`` returns None for non-iterable scalars like bools.) + payload = {k: v for k, v in payload.items() if v is not None} - # define the user agent for the query user_agent = {"user-agent": f"python-dataretrieval/{dataretrieval.__version__}"} - response = requests.get(url, params=payload, headers=user_agent, verify=ssl_check) + try: + response = httpx.get( + url, + params=payload, + headers=user_agent, + verify=ssl_check, + **HTTPX_DEFAULTS, + ) + except httpx.InvalidURL as exc: + raise _url_too_long_error(f"httpx rejected the URL client-side: {exc}") from exc if response.status_code == 400: raise ValueError( @@ -299,24 +331,10 @@ def query(url, payload, delimiter=",", ssl_check=True): + f"URL: {response.url}" ) elif response.status_code == 414: - _reason = response.reason - _example = """ - # n is the number of chunks to divide the query into \n - split_list = np.array_split(site_list, n) - data_list = [] # list to store chunk results in \n - # loop through chunks and make requests \n - for site_list in split_list: \n - data = nwis.get_record(sites=site_list, service='dv', \n - start=start, end=end) \n - data_list.append(data) # append results to list""" - raise ValueError( - "Request URL too long. Modify your query to use fewer sites. " - + f"API response reason: {_reason}. Pseudo-code example of how to " - + f"split your query: \n {_example}" - ) - elif response.status_code in [500, 502, 503]: + raise _url_too_long_error(f"API response reason: {response.reason_phrase}") + elif 500 <= response.status_code < 600: raise ValueError( - f"Service Unavailable: {response.status_code} {response.reason}. " + f"Service Unavailable: {response.status_code} {response.reason_phrase}. " + f"The service at {response.url} may be down or experiencing issues." ) diff --git a/dataretrieval/waterdata/api.py b/dataretrieval/waterdata/api.py index 6f24d80f..57fffc88 100644 --- a/dataretrieval/waterdata/api.py +++ b/dataretrieval/waterdata/api.py @@ -13,11 +13,15 @@ from typing import get_args from urllib.parse import quote +import httpx import pandas as pd -import requests -from requests.models import PreparedRequest -from dataretrieval.utils import BaseMetadata, _attach_datetime_columns, to_str +from dataretrieval.utils import ( + HTTPX_DEFAULTS, + BaseMetadata, + _attach_datetime_columns, + to_str, +) from dataretrieval.waterdata.filters import FILTER_LANG from dataretrieval.waterdata.types import ( CODE_SERVICES, @@ -2110,7 +2114,7 @@ def get_codes(code_service: CODE_SERVICES) -> pd.DataFrame: url = f"{SAMPLES_URL}/codeservice/{code_service}?mimeType=application%2Fjson" - response = requests.get(url, headers=_default_headers()) + response = httpx.get(url, headers=_default_headers(), **HTTPX_DEFAULTS) response.raise_for_status() @@ -2336,12 +2340,14 @@ def get_samples( url = f"{SAMPLES_URL}/{service}/{profile}" - req = PreparedRequest() - req.prepare_url(url, params=params) - logger.debug("Request: %s", req.url) + logger.debug("Request: %s", httpx.URL(url).copy_merge_params(params)) - response = requests.get( - url, params=params, verify=ssl_check, headers=_default_headers() + response = httpx.get( + url, + params=params, + verify=ssl_check, + headers=_default_headers(), + **HTTPX_DEFAULTS, ) response.raise_for_status() @@ -2408,12 +2414,14 @@ def get_samples_summary( url = f"{SAMPLES_URL}/summary/{quote(monitoringLocationIdentifier, safe='')}" params = {"mimeType": "text/csv"} - req = PreparedRequest() - req.prepare_url(url, params=params) - logger.debug("Request: %s", req.url) + logger.debug("Request: %s", httpx.URL(url).copy_merge_params(params)) - response = requests.get( - url, params=params, verify=ssl_check, headers=_default_headers() + response = httpx.get( + url, + params=params, + verify=ssl_check, + headers=_default_headers(), + **HTTPX_DEFAULTS, ) response.raise_for_status() diff --git a/dataretrieval/waterdata/chunking.py b/dataretrieval/waterdata/chunking.py index a6fee155..36ee24fd 100644 --- a/dataretrieval/waterdata/chunking.py +++ b/dataretrieval/waterdata/chunking.py @@ -9,18 +9,12 @@ sub-request URL fits. Requests that already fit get a trivial single-step plan — ``ChunkedCall`` has one code path either way. -Quota: after the first sub-request ``ChunkedCall`` reads -``x-ratelimit-remaining``; if the rest of the plan won't fit, it -raises ``RequestExceedsQuota`` before burning more budget. Set -``API_USGS_LIMIT=0`` to skip this pre-emptive check and attempt the -full plan anyway. - Interruption: any mid-stream transient failure (429, 5xx) surfaces as a ``ChunkInterrupted`` subclass — ``QuotaExhausted`` for 429, ``ServiceInterrupted`` for 5xx. The exception carries ``.call``, a ``ChunkedCall`` handle that owns the already-completed sub-request -state. Call ``.call.resume()`` once the underlying condition clears -to resume; only the still-pending sub-requests are re-issued. +state. Call ``.call.resume()`` once the underlying condition +clears; only the still-pending sub-requests are re-issued. ``Retry-After`` (when the server sets it) is surfaced on the exception as ``.retry_after``. @@ -37,17 +31,18 @@ import functools import itertools import math -import os from collections.abc import Callable, Iterator -from contextlib import contextmanager +from contextlib import contextmanager, suppress from contextvars import ContextVar from dataclasses import dataclass +from datetime import timedelta from typing import Any, ClassVar from urllib.parse import quote_plus +import httpx import pandas as pd -import requests -from requests.structures import CaseInsensitiveDict + +from dataretrieval.utils import HTTPX_DEFAULTS from . import _progress from .filters import ( @@ -98,45 +93,45 @@ # Response header USGS uses to advertise remaining hourly quota. _QUOTA_HEADER = "x-ratelimit-remaining" -# Session shared across all sub-requests of a single chunked call so +# Client shared across all sub-requests of a single chunked call so # paginated-loop helpers downstream (``_walk_pages``) reuse one -# connection pool across the whole fan-out. ``None`` when not inside a +# connection pool across the whole call. ``None`` when not inside a # chunked call — paginated helpers fall back to their own short-lived -# session in that case. -_chunked_session: ContextVar[requests.Session | None] = ContextVar( - "_chunked_session", default=None +# client in that case. +_chunked_client: ContextVar[httpx.Client | None] = ContextVar( + "_chunked_client", default=None ) @contextmanager -def _publish_session(session: requests.Session) -> Iterator[None]: +def _publish_client(client: httpx.Client) -> Iterator[None]: """ - Make ``session`` visible to :func:`get_active_session` for the - duration of the ``with`` block via the ``_chunked_session`` + Make ``client`` visible to :func:`get_active_client` for the + duration of the ``with`` block via the ``_chunked_client`` ContextVar. Wraps the set/reset token dance so callers don't have to. """ - token = _chunked_session.set(session) + token = _chunked_client.set(client) try: yield finally: - _chunked_session.reset(token) + _chunked_client.reset(token) -def get_active_session() -> requests.Session | None: +def get_active_client() -> httpx.Client | None: """ - Return the chunker's currently-published session, or ``None``. + Return the chunker's currently-published sync client, or ``None``. - Public accessor for the ``_chunked_session`` ContextVar so - sibling modules (notably :func:`dataretrieval.waterdata.utils._session`) + Public accessor for the ``_chunked_client`` ContextVar so + sibling modules (notably :func:`dataretrieval.waterdata.utils._client_for`) don't have to reach into the private ContextVar directly. Returns ------- - requests.Session or None - The session published by :func:`_publish_session` if currently + httpx.Client or None + The client published by :func:`_publish_client` if currently inside a :class:`ChunkedCall` ``resume`` block; ``None`` otherwise. """ - return _chunked_session.get() + return _chunked_client.get() # Separators the two axis kinds use to join their atoms back into @@ -145,7 +140,7 @@ def get_active_session() -> requests.Session | None: _LIST_SEP = "," _OR_SEP = " OR " -_FetchOnce = Callable[[dict[str, Any]], tuple[pd.DataFrame, requests.Response]] +_FetchOnce = Callable[[dict[str, Any]], tuple[pd.DataFrame, httpx.Response]] class _RetryableTransportError(RuntimeError): @@ -211,60 +206,6 @@ class RequestTooLarge(ValueError): """ -class RequestExceedsQuota(ValueError): - """ - Remaining rate-limit window can't cover the rest of the chunked plan. - - Raised after a sub-request when ``x-ratelimit-remaining`` in the - response shows the rest of the plan can't fit in the current per-key - rate-limit window. The chunks completed so far have already been - issued and consumed quota; ``ChunkedCall`` stops here rather than - burn more quota on a call that will fail mid-way. The completed - work is preserved on ``.call`` (the originating ``ChunkedCall``) - so callers can recover its ``partial_frame`` / ``partial_response`` - and, once the rate-limit window resets, call ``.call.resume()`` - to continue. - - Attributes - ---------- - planned_chunks : int - Total sub-requests the joint plan would issue. - available : int - Sub-requests this caller can still issue in the current window - (``x-ratelimit-remaining`` + chunks already completed). - deficit : int - ``planned_chunks - available`` — how far over budget the call - would run if it continued. - call : ChunkedCall or None - The originating call handle. ``None`` on hand-constructed - exceptions (test fixtures); otherwise the live handle whose - ``partial_frame`` / ``partial_response`` expose the work - completed before the check fired and whose ``resume()`` can be - called once the rate-limit window rolls over. - """ - - def __init__( - self, - *, - planned_chunks: int, - available: int, - deficit: int, - call: ChunkedCall | None = None, - ) -> None: - super().__init__( - f"Request would issue {planned_chunks} sub-requests but only " - f"{available} fit in the current rate-limit window (short by " - f"{deficit}). Wait for the window to reset, request a higher " - f"per-key quota, narrow the query, or set " - f"API_USGS_LIMIT=0 to bypass this check and risk a " - f"mid-stream 429 (recoverable via QuotaExhausted.resume())." - ) - self.planned_chunks = planned_chunks - self.available = available - self.deficit = deficit - self.call = call - - class ChunkInterrupted(RuntimeError): """ Base class for mid-stream chunk failures whose completed work is @@ -302,7 +243,7 @@ class ChunkInterrupted(RuntimeError): was raised. Snapshot at raise time — does NOT advance on a later ``call.resume()`` (use ``exc.call.partial_frame`` for the live view). - partial_response : requests.Response or None + partial_response : httpx.Response or None Aggregated response covering the completed sub-requests at raise time; ``None`` if nothing had completed yet. Same snapshot semantics as ``partial_frame``. @@ -346,12 +287,15 @@ def __init__( total_chunks: int, call: ChunkedCall | None = None, retry_after: float | None = None, + cause: BaseException | None = None, ) -> None: - super().__init__( - self._MESSAGE_TEMPLATE.format( - completed_chunks=completed_chunks, total_chunks=total_chunks - ) + message = self._MESSAGE_TEMPLATE.format( + completed_chunks=completed_chunks, total_chunks=total_chunks ) + if cause is not None: + cause_msg = str(cause) or type(cause).__name__ + message = f"{message} Cause: {type(cause).__name__}: {cause_msg}" + super().__init__(message) self.completed_chunks = completed_chunks self.total_chunks = total_chunks self.call = call @@ -365,7 +309,7 @@ def __init__( # already comes via ``copy.copy`` from ``_combine_chunk_responses``. if call is None: self.partial_frame: pd.DataFrame = pd.DataFrame() - self.partial_response: requests.Response | None = None + self.partial_response: httpx.Response | None = None else: self.partial_frame = call.partial_frame.copy() self.partial_response = call.partial_response @@ -376,17 +320,10 @@ class QuotaExhausted(ChunkInterrupted): A sub-request returned HTTP 429 — the per-key rate-limit window is exhausted. Subclass of :class:`ChunkInterrupted`. - For a chunked call (``total_chunks > 1``) reached past chunk 0, - the post-first-chunk :class:`RequestExceedsQuota` check normally - short-circuits before burning quota on a plan that won't fit; - arrival here typically means a concurrent caller drained the - window faster than predicted. ``partial_frame`` holds what - completed first. - - For a single-shot call (``total_chunks == 1``) or a 429 on the - very first chunk, ``partial_frame`` is empty and - ``partial_response`` is ``None``; the original ``RateLimited`` is - on ``__cause__``. + The completed sub-requests are preserved on ``.call``; once the + rate-limit window resets, ``.call.resume()`` re-issues only the + still-pending work. ``partial_frame`` holds what completed + before the 429. """ _MESSAGE_TEMPLATE = ( @@ -414,47 +351,112 @@ class ServiceInterrupted(ChunkInterrupted): ) -def _request_bytes(req: requests.PreparedRequest) -> int: +def _request_bytes(req: httpx.Request) -> int: + """ + Return the total bytes of an httpx request: URL + body. + + GET routes have empty ``.content`` and reduce to URL length. POST + routes (CQL2 JSON body) need body bytes — the URL stays short + regardless of payload, so URL-only sizing would underestimate the + request and skip chunking when it's needed. + + Parameters + ---------- + req : httpx.Request + The request to size. + + Returns + ------- + int + ``len(str(req.url)) + len(req.content)``. ``httpx.URL`` doesn't + support ``len()`` directly, so the str-coercion is required. + """ + return len(str(req.url)) + len(req.content) + + +def _safe_request_bytes( + build_request: Callable[..., httpx.Request], + args: dict[str, Any], + url_limit: int, +) -> int: """ - Total bytes of a prepared request: URL + body. + Size a candidate sub-request, treating ``httpx.InvalidURL`` as + "still too large". - GET routes have ``body=None`` and reduce to URL length. POST routes - (CQL2 JSON body) need body bytes — the URL stays short regardless - of payload, so URL-only sizing would underestimate the request and - skip chunking when it's needed. + ``httpx.URL`` enforces a hard 64 KB cap per URL component + (``MAX_URL_LENGTH``) and raises ``httpx.InvalidURL`` for anything + bigger. We report ``url_limit + 1`` on overflow so the greedy + halving loop in :meth:`ChunkPlan._plan` keeps shrinking the + largest axis until ``httpx.Request`` can be constructed at all. Parameters ---------- - req : requests.PreparedRequest - The prepared request to size. + build_request : Callable[..., httpx.Request] + Factory that turns a kwargs dict into a sized request. + args : dict[str, Any] + Per-sub-request kwargs to pass through to ``build_request``. + url_limit : int + The chunker's byte budget; returned + 1 on overflow. Returns ------- int - ``len(req.url) + len(req.body)`` where ``req.body`` is treated - as 0 bytes when ``None`` and UTF-8 encoded when ``str``. + Real byte count when the request builds, otherwise + ``url_limit + 1`` so the planner's "too large" branch keeps + halving. + """ + try: + req = build_request(**args) + except httpx.InvalidURL: + return url_limit + 1 + return _request_bytes(req) - Raises - ------ - TypeError - If ``req.body`` is not ``None``, ``bytes``/``bytearray``, or - ``str``. Size-based planning needs a deterministic byte count, - so generators and file-like streams are rejected up front - rather than silently treated as zero bytes. + +def _safe_elapsed(response: httpx.Response) -> timedelta: """ - body = req.body - if body is None: - body_len = 0 - elif isinstance(body, (bytes, bytearray)): - body_len = len(body) - elif isinstance(body, str): - body_len = len(body.encode("utf-8")) - else: - raise TypeError( - f"multi_value_chunked cannot size a request body of type " - f"{type(body).__name__!r}; pass str, bytes, or None." + Read ``response.elapsed``, falling back to ``timedelta(0)`` when + the attribute hasn't been populated. + + httpx only writes ``.elapsed`` when a response is closed through + its normal transport path. ``MockTransport`` (used by + ``pytest-httpx``) and hand-constructed ``httpx.Response`` objects + leave the attribute unset, so accessing it raises ``RuntimeError``. + Combining responses across chunks needs a defined duration, so we + treat the missing attribute as zero elapsed. + """ + try: + return response.elapsed + except RuntimeError: + return timedelta(0) + + +def _set_response_url(response: httpx.Response, url: str | httpx.URL) -> None: + """ + Overwrite the URL surfaced by a response without back-propagating + the change into any aliased original. + + On real ``httpx.Response`` instances ``.url`` is a read-only + property that resolves through the bound request; rather than + mutate the existing request's URL (which would be visible through + any shallow copy that shares the same ``.request``), we replace + the response's request with a fresh :class:`httpx.Request` carrying + the new URL. On lightweight test mocks ``.url`` is a plain + writable attribute — that path is tried first. + """ + try: + response.url = url # type: ignore[misc] + except AttributeError: + target = httpx.URL(str(url)) + try: + old = response.request + except RuntimeError: + # No request bound (some hand-built httpx.Response fixtures); + # synthesize a minimal one to hold the URL. + response.request = httpx.Request("GET", target) + return + response.request = httpx.Request( + method=old.method, url=target, headers=old.headers ) - return len(req.url) + body_len @dataclass(frozen=True) @@ -489,7 +491,8 @@ class _Axis: def chunk_bytes(self, chunk: list[str]) -> int: """ - URL-encoded bytes a chunk contributes when substituted. + Return the URL-encoded byte count this chunk contributes when + substituted into the request. ``quote_plus`` is faithful to what the real URL builder produces, so values containing characters that expand under URL @@ -588,11 +591,11 @@ class ChunkPlan: ---------- args : dict[str, Any] The user-level request kwargs. - build_request : Callable[..., requests.PreparedRequest] - Factory that turns a kwargs dict into a sized prepared - request, e.g. ``_construct_api_requests``. + build_request : Callable[..., httpx.Request] + Factory that turns a kwargs dict into a sized httpx request, + e.g. ``_construct_api_requests``. url_limit : int - Byte budget for the prepared request (URL + body). + Byte budget for the request (URL + body). Attributes ---------- @@ -607,12 +610,10 @@ class ChunkPlan: Per-axis partition: ``chunks[axis.arg_key]`` is the list of atom-sublists this axis is split into. Empty in passthrough. canonical_url : str or None - URL of the full original request, used to overwrite the first - chunk's ``response.url`` so ``BaseMetadata`` reflects the - user's full query. ``None`` on the nothing-to-chunk passthrough - path — ``fetch_once``'s response already carries the canonical - URL there, so ``ChunkedCall`` skips the override to avoid an - extra ``build_request`` call on the hot path. + URL of the user's original (un-chunked) request, used to + overwrite a chunked response's ``.url`` so ``BaseMetadata`` + reflects the full query. ``None`` on the passthrough path + and when no buildable URL exists. Raises ------ @@ -624,7 +625,7 @@ class ChunkPlan: def __init__( self, args: dict[str, Any], - build_request: Callable[..., requests.PreparedRequest], + build_request: Callable[..., httpx.Request], url_limit: int, ) -> None: self.args = args @@ -635,22 +636,45 @@ def __init__( axes = _extract_axes(args) # No chunkable axes → skip ``build_request`` entirely; the # common Water Data call shape shouldn't pay for an unused - # request prep on the passthrough hot path. + # request prep on the passthrough hot path. ``fetch_once`` + # will run with the user's args verbatim; if that produces + # an over-budget URL, the server (or httpx itself) rejects. if not axes: return - initial_request = build_request(**args) - self.canonical_url = initial_request.url - if _request_bytes(initial_request) <= url_limit: - return + # Constructing the initial request can itself trip + # ``httpx.InvalidURL`` (URL > 64 KB) — that's the canonical + # "needs chunking" signal, so swallow it and proceed to plan. + # When the unchunked URL does build, preserve it as + # ``canonical_url`` so ``BaseMetadata.url`` echoes the user's + # original query verbatim; only fall back to a worst-case + # sub-request URL when the URL itself can't be constructed. + try: + initial_request = build_request(**args) + except httpx.InvalidURL: + initial_request = None + + if initial_request is not None: + self.canonical_url = str(initial_request.url) + if _request_bytes(initial_request) <= url_limit: + return self.axes = axes self.chunks = {axis.arg_key: [list(axis.atoms)] for axis in axes} self._plan(build_request, url_limit) + if self.canonical_url is None: + # Original URL was un-constructable (httpx.InvalidURL); fall + # back to the worst-case sub-request URL so + # ``BaseMetadata.url`` still surfaces something + # informative. If even that overflows, leave canonical_url + # as None (set above) and let the response's own URL stand. + with suppress(httpx.InvalidURL): + self.canonical_url = str(build_request(**self._worst_case_args()).url) + def _plan( self, - build_request: Callable[..., requests.PreparedRequest], + build_request: Callable[..., httpx.Request], url_limit: int, ) -> None: """ @@ -668,7 +692,7 @@ def _plan( """ while True: worst = self._worst_case_args() - if _request_bytes(build_request(**worst)) <= url_limit: + if _safe_request_bytes(build_request, worst, url_limit) <= url_limit: return biggest_axis: _Axis | None = None @@ -743,7 +767,7 @@ def iter_sub_args(self) -> Iterator[dict[str, Any]]: sub_args[axis.arg_key] = axis.render(chunk) yield sub_args - def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, requests.Response]: + def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, httpx.Response]: """ Run the plan and return the combined ``(frame, response)``. @@ -760,7 +784,7 @@ def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, requests.Respon ------- df : pandas.DataFrame Combined data from every successful sub-request. - response : requests.Response + response : httpx.Response Aggregated response (canonical URL, last page's headers, cumulative elapsed time). @@ -771,54 +795,10 @@ def execute(self, fetch_once: _FetchOnce) -> tuple[pd.DataFrame, requests.Respon (:class:`QuotaExhausted` for 429, :class:`ServiceInterrupted` for 5xx). The resumable handle is on ``exc.call``. - RequestExceedsQuota - When the rate-limit window can't cover the remaining plan. """ return ChunkedCall(self, fetch_once).resume() -def _quota_check_disabled() -> bool: - """ - Check whether the pre-emptive quota check is disabled. - - Read at call time (not import time) so test patches via - ``monkeypatch.setenv`` take effect. - - Returns - ------- - bool - ``True`` when the environment variable ``API_USGS_LIMIT`` is - set to ``"0"`` (stripped), bypassing the post-first-chunk - :class:`RequestExceedsQuota` check. - """ - return os.environ.get("API_USGS_LIMIT", "").strip() == "0" - - -def _read_remaining(response: requests.Response) -> int | None: - """ - Parse the ``x-ratelimit-remaining`` header from a response. - - Parameters - ---------- - response : requests.Response - A response that may or may not carry the quota header. - - Returns - ------- - int or None - The parsed integer, or ``None`` when the header is missing or - unparseable. ``ChunkedCall`` treats ``None`` as "no quota - signal" and skips the post-first-chunk plan check. - """ - raw = response.headers.get(_QUOTA_HEADER) - if raw is None: - return None - try: - return int(raw) - except (TypeError, ValueError): - return None - - def _classify_chunk_error( exc: BaseException, ) -> tuple[type[ChunkInterrupted], float | None] | None: @@ -850,11 +830,13 @@ def _classify_chunk_error( ``__cause__``, so this function must walk the chain rather than just ``isinstance`` the top-level exception. - Bare ``requests.exceptions.RequestException`` (ConnectionError, - Timeout, SSLError, …) is also treated as a transient transport - failure and wrapped as :class:`ServiceInterrupted` — these don't - inherit from ``RuntimeError`` and would otherwise escape the - chunker's catch with no resumable handle. + Bare ``httpx.HTTPError`` (``ConnectError``, ``TimeoutException``, + etc.) and ``httpx.InvalidURL`` (server-supplied cursor URL too + long, oversize follow-up) are also treated as transport failures + and wrapped as :class:`ServiceInterrupted` — these don't inherit + from ``RuntimeError`` (and ``InvalidURL`` doesn't even inherit + from ``HTTPError``), so without explicit handling they would + escape the chunker's catch with no resumable handle. """ cur: BaseException | None = exc while cur is not None: @@ -862,7 +844,7 @@ def _classify_chunk_error( return QuotaExhausted, cur.retry_after if isinstance(cur, ServiceUnavailable): return ServiceInterrupted, cur.retry_after - if isinstance(cur, requests.exceptions.RequestException): + if isinstance(cur, (httpx.HTTPError, httpx.InvalidURL)): return ServiceInterrupted, None cur = cur.__cause__ return None @@ -930,59 +912,63 @@ def _combine_chunk_frames(frames: list[pd.DataFrame]) -> pd.DataFrame: def _combine_chunk_responses( - responses: list[requests.Response], canonical_url: str | None -) -> requests.Response: + responses: list[httpx.Response], canonical_url: str | None +) -> httpx.Response: """ Fold per-sub-request responses into a single aggregated response. - Returns a shallow copy of ``responses[0]`` with ``.headers`` set to - the last response's (so ``x-ratelimit-remaining`` reflects current - state), ``.elapsed`` set to total wall-clock across every response, - and ``.url`` set to the canonical original-query URL so - ``BaseMetadata`` reflects the user's full request rather than the - first chunk. + For a multi-response input, returns a shallow copy of + ``responses[0]`` with ``.headers`` set to the last response's (so + ``x-ratelimit-remaining`` reflects current state), ``.elapsed`` set + to total wall-clock across every response, and ``.url`` set to the + canonical original-query URL (when supplied) so ``BaseMetadata`` + reflects the user's full request rather than the first chunk. + + For a single-response input with no canonical-URL override, + ``responses[0]`` is returned unchanged to skip the copy on the + passthrough hot path. Parameters ---------- - responses : list[requests.Response] + responses : list[httpx.Response] One response per completed sub-request, in execution order. canonical_url : str or None URL of the unchunked original request. ``None`` skips the URL - override — used by the trivial-passthrough path where - ``fetch_once`` already returns a response whose ``.url`` is - the original-query URL. + override — used by the passthrough path (``fetch_once``'s + response already carries the original-query URL) and by the + worst-case overflow path (no buildable canonical URL exists). Returns ------- - requests.Response + httpx.Response A shallow copy of the first response with aggregated ``headers``, ``elapsed``, and ``url``. The function is idempotent (the input responses' ``headers`` / ``elapsed`` / ``url`` are never mutated), so it's safe to call repeatedly via :attr:`ChunkedCall.partial_response` during error inspection or resume retries. ``headers`` on the returned - object is a fresh ``CaseInsensitiveDict``, so mutations there - don't back-propagate into any chunk's underlying response. - Note that other ``Response`` fields (``_content``, ``raw``, - ``cookies``, ``request``) are still aliased to the first - chunk by the shallow copy — callers that mutate those will - affect the underlying chunk response. + object is a fresh ``httpx.Headers``, so mutations there don't + back-propagate into any chunk's underlying response. """ + if len(responses) == 1 and canonical_url is None: + return responses[0] + # ``copy.copy`` lets repeated calls re-sum elapsed from scratch # rather than re-mutating ``responses[0]`` in place. The headers - # dict is then rewrapped in a fresh ``CaseInsensitiveDict`` so the + # dict is then rewrapped in a fresh ``httpx.Headers`` so the # aggregate's headers don't share identity with — or leak mutations # back into — any underlying response on ``ChunkedCall._chunks``. head = copy.copy(responses[0]) if len(responses) > 1: - head.headers = CaseInsensitiveDict(responses[-1].headers) + head.headers = httpx.Headers(responses[-1].headers) head.elapsed = sum( - (r.elapsed for r in responses[1:]), start=responses[0].elapsed + (_safe_elapsed(r) for r in responses[1:]), + start=_safe_elapsed(responses[0]), ) else: - head.headers = CaseInsensitiveDict(responses[0].headers) + head.headers = httpx.Headers(responses[0].headers) if canonical_url is not None: - head.url = canonical_url + _set_response_url(head, canonical_url) return head @@ -1000,12 +986,12 @@ class ChunkedCall: executes; callers reach it via :attr:`ChunkInterrupted.call` on the exception raised by a mid-stream failure. - :meth:`resume` is idempotent: it skips sub-requests already - completed (``self.completed_chunks`` is the cursor) and re-issues - only the still-pending ones. The sub-request - ordering matches :meth:`ChunkPlan.iter_sub_args`, which is - deterministic, so each call picks up exactly where the previous - one stopped. + :meth:`resume` is idempotent: it iterates + :meth:`ChunkPlan.iter_sub_args` (deterministic order) and skips + any index whose result is already in ``self._chunks``. The + completion set is a ``dict[int, (df, response)]`` keyed by + sub-args index; a subsequent ``resume`` only re-issues + sub-requests whose index isn't already present. Parameters ---------- @@ -1021,14 +1007,10 @@ class ChunkedCall: The plan being driven (read-only after construction). fetch_once : Callable The per-sub-request fetch function. - completed_chunks : int - Number of sub-requests successfully completed so far. - total_chunks : int - Total sub-requests in ``plan`` (``== plan.total``). partial_frame : pandas.DataFrame Combined frame of completed sub-requests (live; recomputed per access). - partial_response : requests.Response or None + partial_response : httpx.Response or None Aggregated response with canonical URL restored, or ``None`` when nothing has completed yet (live; recomputed per access). """ @@ -1036,19 +1018,12 @@ class ChunkedCall: def __init__(self, plan: ChunkPlan, fetch_once: _FetchOnce) -> None: self.plan = plan self.fetch_once = fetch_once - # One entry per completed sub-request, in execution order. - # A single list keeps the (frame, response) pair atomic so the - # ``len(_chunks)`` cursor can't ever drift between two parallel - # lists. - self._chunks: list[tuple[pd.DataFrame, requests.Response]] = [] + # Completed (frame, response) pairs keyed by sub-args index; + # ``resume()`` skips indices already present. + self._chunks: dict[int, tuple[pd.DataFrame, httpx.Response]] = {} - @property - def completed_chunks(self) -> int: - return len(self._chunks) - - @property - def total_chunks(self) -> int: - return self.plan.total + def _ordered_chunks(self) -> list[tuple[pd.DataFrame, httpx.Response]]: + return [self._chunks[i] for i in sorted(self._chunks)] @property def partial_frame(self) -> pd.DataFrame: @@ -1067,10 +1042,10 @@ def partial_frame(self) -> pd.DataFrame: """ if not self._chunks: return pd.DataFrame() - return _combine_chunk_frames([frame for frame, _ in self._chunks]) + return _combine_chunk_frames([frame for frame, _ in self._ordered_chunks()]) @property - def partial_response(self) -> requests.Response | None: + def partial_response(self) -> httpx.Response | None: """ Aggregated response with the canonical URL restored to the user's full original query. @@ -1079,38 +1054,37 @@ def partial_response(self) -> requests.Response | None: Returns ------- - requests.Response or None + httpx.Response or None Aggregated response when at least one sub-request has completed, ``None`` otherwise. """ if not self._chunks: return None return _combine_chunk_responses( - [resp for _, resp in self._chunks], self.plan.canonical_url + [resp for _, resp in self._ordered_chunks()], self.plan.canonical_url ) - def resume(self) -> tuple[pd.DataFrame, requests.Response]: + def resume(self) -> tuple[pd.DataFrame, httpx.Response]: """ - Drive the chunked call to completion. + Drive the chunked call to completion via the sync ``fetch_once``. - Opens one ``requests.Session`` for the run and publishes it on - the ``_chunked_session`` ``ContextVar`` so paginated-loop + Opens one ``httpx.Client`` for the run and publishes it on + the ``_chunked_client`` ``ContextVar`` so paginated-loop helpers downstream (``_walk_pages``) reuse the same connection pool across every sub-request instead of handshaking fresh on - each. The session is closed when ``resume`` returns or raises; + each. The client is closed when ``resume`` returns or raises; a follow-up ``resume`` call (after a ``ChunkInterrupted``) opens a new one. - Idempotent: starts from chunk 0 on the first call, then from - the cursor (``self.completed_chunks``) on every subsequent - call. Re-issues only sub-requests that haven't already - completed. + Idempotent: only sub-requests whose index isn't already in + ``self._chunks`` are re-issued. Sub-args order matches + :meth:`ChunkPlan.iter_sub_args` and is deterministic. Returns ------- df : pandas.DataFrame Combined data from every successful sub-request. - response : requests.Response + response : httpx.Response Aggregated response (canonical URL, last page's headers, cumulative elapsed time). @@ -1122,41 +1096,44 @@ def resume(self) -> tuple[pd.DataFrame, requests.Response]: :class:`ServiceInterrupted` for 5xx). The resumable handle is on ``exc.call`` — wait for the underlying condition to clear and call ``exc.call.resume()`` again. - RequestExceedsQuota - When the rate-limit window can't cover the remaining plan - (checked after the first sub-request). """ - with requests.Session() as session, _publish_session(session): + with httpx.Client(**HTTPX_DEFAULTS) as client, _publish_client(client): reporter = _progress.current() if reporter is not None: reporter.set_chunks(self.plan.total) - completed = len(self._chunks) for i, sub_args in enumerate(self.plan.iter_sub_args()): - if i < completed: + if i in self._chunks: continue if reporter is not None: reporter.start_chunk(i + 1) - self._issue(sub_args) - frames = [frame for frame, _ in self._chunks] - responses = [resp for _, resp in self._chunks] + self._issue(i, sub_args) + ordered = self._ordered_chunks() + frames = [frame for frame, _ in ordered] + responses = [resp for _, resp in ordered] return ( _combine_chunk_frames(frames), _combine_chunk_responses(responses, self.plan.canonical_url), ) - def _issue(self, sub_args: dict[str, Any]) -> None: - # Catch both ``RuntimeError`` (the layer's typed contract: - # ``RateLimited`` / ``ServiceUnavailable`` / mid-pagination - # wrapper) and ``requests.exceptions.RequestException`` - # (transport-level failures like ConnectionError / Timeout / - # SSLError that bubble up unmodified from - # ``sess.send(initial_req)`` and don't inherit from - # RuntimeError). Both routes go through ``_classify_chunk_error`` - # so transient failures become resumable ``ChunkInterrupted`` - # subclasses; unknown failures re-raise to preserve their type. + def _issue(self, index: int, sub_args: dict[str, Any]) -> None: + """ + Issue one sub-request and record its ``(frame, response)`` pair + under ``index``. + + On failure, classify the exception and either wrap it as a + resumable :class:`ChunkInterrupted` carrying this call, or + re-raise it unchanged to preserve its type. Catches + ``RuntimeError`` (the layer's typed contract: + :class:`RateLimited`, :class:`ServiceUnavailable`, or the + mid-pagination wrapper), :class:`httpx.HTTPError` + (transport-level failures like ``ConnectError`` / + ``TimeoutException``), and :class:`httpx.InvalidURL` (which + inherits directly from ``Exception``, not ``HTTPError``); all + three feed :func:`_classify_chunk_error`. + """ try: - chunk = self.fetch_once(sub_args) - except (RuntimeError, requests.exceptions.RequestException) as exc: + self._chunks[index] = self.fetch_once(sub_args) + except (RuntimeError, httpx.HTTPError, httpx.InvalidURL) as exc: classification = _classify_chunk_error(exc) if classification is None: raise @@ -1166,31 +1143,13 @@ def _issue(self, sub_args: dict[str, Any]) -> None: total_chunks=self.plan.total, call=self, retry_after=retry_after, + cause=exc, ) from exc - self._chunks.append(chunk) - if len(self._chunks) < self.plan.total: - self._check_quota_remaining() - - def _check_quota_remaining(self) -> None: - if _quota_check_disabled(): - return - _, last_response = self._chunks[-1] - remaining = _read_remaining(last_response) - completed = len(self._chunks) - pending = self.plan.total - completed - if remaining is None or remaining >= pending: - return - raise RequestExceedsQuota( - planned_chunks=self.plan.total, - available=remaining + completed, - deficit=pending - remaining, - call=self, - ) def multi_value_chunked( *, - build_request: Callable[..., requests.PreparedRequest], + build_request: Callable[..., httpx.Request], url_limit: int | None = None, ) -> Callable[[_FetchOnce], _FetchOnce]: """ @@ -1204,15 +1163,15 @@ def multi_value_chunked( Parameters ---------- - build_request : Callable[..., requests.PreparedRequest] - Factory that turns a kwargs dict into a sized prepared - request, e.g. ``_construct_api_requests``. Called during - planning to measure each candidate plan. + build_request : Callable[..., httpx.Request] + Factory that turns a kwargs dict into a sized httpx request, + e.g. ``_construct_api_requests``. Called during planning to + measure each candidate plan. url_limit : int, optional - Byte budget for the prepared request (URL + body). When - ``None`` (default), the module-level - ``_WATERDATA_URL_BYTE_LIMIT`` is resolved at call time so test - patches via ``monkeypatch.setattr`` take effect. + Byte budget for the request (URL + body). When ``None`` + (default), the module-level ``_WATERDATA_URL_BYTE_LIMIT`` is + resolved at call time so test patches via + ``monkeypatch.setattr`` take effect. Returns ------- @@ -1225,9 +1184,6 @@ def multi_value_chunked( ------ RequestTooLarge If no plan can fit ``url_limit``. - RequestExceedsQuota - After the first sub-request, if the remaining plan can't fit - the current rate-limit window. ChunkInterrupted On a mid-execution 429 (:class:`QuotaExhausted`) or 5xx (:class:`ServiceInterrupted`). See :class:`ChunkedCall` for @@ -1243,9 +1199,10 @@ def decorator(fetch_once: _FetchOnce) -> _FetchOnce: @functools.wraps(fetch_once) def wrapper( args: dict[str, Any], - ) -> tuple[pd.DataFrame, requests.Response]: + ) -> tuple[pd.DataFrame, httpx.Response]: limit = _WATERDATA_URL_BYTE_LIMIT if url_limit is None else url_limit - return ChunkPlan(args, build_request, limit).execute(fetch_once) + plan = ChunkPlan(args, build_request, limit) + return plan.execute(fetch_once) return wrapper diff --git a/dataretrieval/waterdata/ratings.py b/dataretrieval/waterdata/ratings.py index a37c88b5..0e1b503d 100644 --- a/dataretrieval/waterdata/ratings.py +++ b/dataretrieval/waterdata/ratings.py @@ -17,10 +17,11 @@ from collections.abc import Iterable from typing import Any, Literal, get_args +import httpx import pandas as pd -import requests from dataretrieval.rdb import extract_rdb_comment, read_rdb +from dataretrieval.utils import HTTPX_DEFAULTS from .utils import ( _DURATION_RE, @@ -186,7 +187,7 @@ def get_ratings( fid = feature["id"] try: out[fid] = _download_and_parse(feature, file_path, ssl_check) - except (requests.RequestException, ValueError, OSError) as e: + except (httpx.HTTPError, ValueError, OSError) as e: logger.warning("Failed to download / parse %s: %s", fid, e) return out @@ -240,11 +241,12 @@ def _search( if bbox is not None: params["bbox"] = ",".join(map(str, bbox)) - response = requests.get( + response = httpx.get( f"{STAC_URL}/search", params=params, headers=_default_headers(), verify=ssl_check, + **HTTPX_DEFAULTS, ) response.raise_for_status() return response.json().get("features", []) @@ -257,7 +259,9 @@ def _download_and_parse( ) -> pd.DataFrame: """Fetch the feature's data asset, parse RDB, optionally persist to disk.""" url = feature["assets"]["data"]["href"] - response = requests.get(url, headers=_default_headers(), verify=ssl_check) + response = httpx.get( + url, headers=_default_headers(), verify=ssl_check, **HTTPX_DEFAULTS + ) response.raise_for_status() if file_path is not None: diff --git a/dataretrieval/waterdata/utils.py b/dataretrieval/waterdata/utils.py index dd908143..66ed1723 100644 --- a/dataretrieval/waterdata/utils.py +++ b/dataretrieval/waterdata/utils.py @@ -1,28 +1,34 @@ from __future__ import annotations import copy +import functools import json import logging import os import re -from collections.abc import Callable, Iterable, Iterator, Mapping +from collections.abc import ( + Callable, + Iterable, + Iterator, + Mapping, +) from contextlib import contextmanager from datetime import datetime, timedelta from typing import Any, TypeVar, get_args from zoneinfo import ZoneInfo +import httpx import pandas as pd -import requests -from requests.structures import CaseInsensitiveDict from dataretrieval import __version__ -from dataretrieval.utils import BaseMetadata +from dataretrieval.utils import HTTPX_DEFAULTS, BaseMetadata from dataretrieval.waterdata import _progress, chunking from dataretrieval.waterdata.chunking import ( _QUOTA_HEADER, RateLimited, ServiceUnavailable, - get_active_session, + _safe_elapsed, + get_active_client, ) from dataretrieval.waterdata.types import ( PROFILE_LOOKUP, @@ -366,24 +372,26 @@ def _check_ogc_requests(endpoint: str = "daily", req_type: str = "queryables"): ------ ValueError If req_type is not "queryables" or "schema". - requests.HTTPError - If the HTTP request returns an unsuccessful status code. + RateLimited, ServiceUnavailable, RuntimeError + From :func:`_raise_for_non_200` on any non-200 — same typed + contract as the main data path so callers can use one + ``except`` clause everywhere. """ if req_type not in ("queryables", "schema"): raise ValueError(f"req_type must be 'queryables' or 'schema', got {req_type!r}") url = f"{OGC_API_URL}/collections/{endpoint}/{req_type}" - resp = requests.get(url, headers=_default_headers()) - resp.raise_for_status() + resp = httpx.get(url, headers=_default_headers(), **HTTPX_DEFAULTS) + _raise_for_non_200(resp) return resp.json() -def _error_body(resp: requests.Response): +def _error_body(resp: httpx.Response): """ Build an informative error message from an HTTP response. Parameters ---------- - resp : requests.Response + resp : httpx.Response The HTTP response object to extract the error message from. Returns @@ -418,7 +426,7 @@ def _error_body(resp: requests.Response): j_txt = resp.json() except ValueError: snippet = (resp.text or "").strip()[:200] - reason = resp.reason or "Error" + reason = resp.reason_phrase or "Error" if snippet: return f"{status}: {reason}. {snippet}" return f"{status}: {reason}." @@ -459,18 +467,18 @@ def _parse_retry_after(value: str | None) -> float | None: return None -def _raise_for_non_200(resp: requests.Response) -> None: +def _raise_for_non_200(resp: httpx.Response) -> None: """ Raise a typed exception for any non-200 response. Routes through :func:`_error_body` (USGS-API-aware: handles 429/403 specially, extracts ``code``/``description`` from JSON error bodies) rather than ``Response.raise_for_status``, which - raises ``HTTPError`` with a generic message. + raises ``HTTPStatusError`` with a generic message. Parameters ---------- - resp : requests.Response + resp : httpx.Response The HTTP response to inspect. Raises @@ -520,7 +528,7 @@ def _paginated_failure_message(pages_collected: int, cause: BaseException) -> st and ``get_stats_data`` raise from the original exception. """ cause_str = str(cause).removesuffix(".") - # Some ``requests`` exceptions (e.g. ``Timeout()`` with no args) + # Some ``httpx`` exceptions (e.g. ``TimeoutException()`` with no args) # stringify to empty; fall back to the class name so the # returned message is always informative. if not cause_str.strip(): @@ -544,7 +552,7 @@ def _construct_api_requests( limit: int | None = None, skip_geometry: bool = False, **kwargs, -): +) -> httpx.Request: """ Constructs an HTTP request object for the specified water data API service. @@ -572,7 +580,7 @@ def _construct_api_requests( Returns ------- - requests.PreparedRequest + httpx.Request The constructed HTTP request object ready to be sent. Notes @@ -626,25 +634,23 @@ def _construct_api_requests( if post_params: headers["Content-Type"] = "application/query-cql-json" - request = requests.Request( + return httpx.Request( method="POST", url=service_url, headers=headers, - data=_cql2_param(post_params), - params=params, - ) - else: - request = requests.Request( - method="GET", - url=service_url, - headers=headers, + content=_cql2_param(post_params), params=params, ) - return request.prepare() + return httpx.Request( + method="GET", + url=service_url, + headers=headers, + params=params, + ) def _next_req_url( - resp: requests.Response, *, body: dict[str, Any] | None = None + resp: httpx.Response, *, body: dict[str, Any] | None = None ) -> str | None: """ Extracts the URL for the next page of results from an HTTP response from a @@ -652,7 +658,7 @@ def _next_req_url( Parameters ---------- - resp : requests.Response + resp : httpx.Response The HTTP response object containing JSON data and headers. body : dict, optional Pre-parsed JSON body for ``resp``. When provided, skips the @@ -676,13 +682,38 @@ def _next_req_url( if not body.get("numberReturned"): return None for link in body.get("links", []): - if link.get("rel") == "next": - return link.get("href") + if link.get("rel") != "next": + continue + href = link.get("href") + if not href: + return href + # Refuse to follow a next-page link to a different host — + # the request's headers/auth were minted for the original + # host and shouldn't leak to whatever a poisoned response + # body might supply. Guarded against mock-shaped ``resp.url`` + # attributes (tests sometimes set strings or ``MagicMock``) + # by falling open when host extraction isn't reliable. + try: + next_host = httpx.URL(href).host + resp_url = ( + resp.url + if isinstance(resp.url, httpx.URL) + else httpx.URL(str(resp.url)) + ) + cur_host = resp_url.host + except (httpx.InvalidURL, TypeError): + next_host = cur_host = None + if next_host and cur_host and next_host != cur_host: + raise RuntimeError( + f"Refusing to follow cross-host next-page URL: " + f"{next_host} != {cur_host}" + ) + return href return None def _get_resp_data( - resp: requests.Response, + resp: httpx.Response, geopd: bool, *, body: dict[str, Any] | None = None, @@ -692,7 +723,7 @@ def _get_resp_data( Parameters ---------- - resp : requests.Response + resp : httpx.Response The HTTP response object expected to contain a JSON body with a "features" key. geopd : bool @@ -769,47 +800,48 @@ def _get_resp_data( @contextmanager -def _session(client: requests.Session | None) -> Iterator[requests.Session]: +def _client_for(client: httpx.Client | None) -> Iterator[httpx.Client]: """ - Yield a usable session, picking the best available source. + Yield a usable client, picking the best available source. Resolution order: 1. ``client`` if the caller supplied one (borrowed; not closed here — the caller owns its lifecycle). - 2. The chunker's shared session if we're inside a ``ChunkedCall`` - fan-out (per :func:`chunking.get_active_session`). Borrowed; + 2. The chunker's shared client if we're inside a + ``ChunkedCall.resume()`` block (per + :func:`chunking.get_active_client`). Borrowed; ``ChunkedCall.resume`` closes it on exit. - 3. A fresh short-lived ``requests.Session`` opened here and closed + 3. A fresh short-lived ``httpx.Client`` opened here and closed on context exit. Parameters ---------- - client : requests.Session or None - A caller-owned session to borrow, or ``None`` to defer to the - chunker's shared session or a temporary one. + client : httpx.Client or None + A caller-owned client to borrow, or ``None`` to defer to the + chunker's shared client or a temporary one. Yields ------ - requests.Session - The chosen session. + httpx.Client + The chosen client. """ if client is not None: yield client return - shared = get_active_session() + shared = get_active_client() if shared is not None: yield shared return - with requests.Session() as new: + with httpx.Client(**HTTPX_DEFAULTS) as new: yield new def _aggregate_paginated_response( - initial: requests.Response, - last: requests.Response, + initial: httpx.Response, + last: httpx.Response, total_elapsed: timedelta, -) -> requests.Response: +) -> httpx.Response: """ Build a single response covering a paginated call. @@ -823,27 +855,24 @@ def _aggregate_paginated_response( Parameters ---------- - initial : requests.Response + initial : httpx.Response First-page response (the canonical one for ``md.url``). - last : requests.Response + last : httpx.Response Last-page response — supplies the headers to copy over. total_elapsed : datetime.timedelta Cumulative wall-clock across every page, including ``initial``. Returns ------- - requests.Response + httpx.Response A shallow copy of ``initial`` with ``.headers`` set to a fresh - ``CaseInsensitiveDict`` and ``.elapsed`` set to the - cumulative wall-clock. ``initial.headers`` / ``initial.elapsed`` - are never mutated, so callers holding a pre-pagination - reference still see the original first-page values. Other - ``Response`` fields (``_content``, ``raw``, ``cookies``, - ``request``) are still aliased to ``initial`` by the shallow - copy — callers that mutate those will affect ``initial``. + ``httpx.Headers`` and ``.elapsed`` set to the cumulative + wall-clock. ``initial.headers`` / ``initial.elapsed`` are + never mutated, so callers holding a pre-pagination reference + still see the original first-page values. """ final = copy.copy(initial) - final.headers = CaseInsensitiveDict(last.headers) + final.headers = httpx.Headers(last.headers) final.elapsed = total_elapsed return final @@ -852,12 +881,12 @@ def _aggregate_paginated_response( def _paginate( - initial_req: requests.PreparedRequest, + initial_req: httpx.Request, *, - parse_response: Callable[[requests.Response], tuple[pd.DataFrame, _Cursor | None]], - follow_up: Callable[[_Cursor, requests.Session], requests.Response], - client: requests.Session | None = None, -) -> tuple[pd.DataFrame, requests.Response]: + parse_response: Callable[[httpx.Response], tuple[pd.DataFrame, _Cursor | None]], + follow_up: Callable[[_Cursor, httpx.Client], httpx.Response], + client: httpx.Client | None = None, +) -> tuple[pd.DataFrame, httpx.Response]: """ Drive a paginated request to completion. @@ -870,27 +899,27 @@ def _paginate( Parameters ---------- - initial_req : requests.PreparedRequest + initial_req : httpx.Request First-page request to send. parse_response : callable ``resp -> (df, next_cursor_or_None)``. Returns the page's DataFrame and the cursor (URL, token, …) used to drive ``follow_up`` for the next page; ``None`` terminates the loop. follow_up : callable - ``(cursor, session) -> requests.Response``. Builds and sends + ``(cursor, client) -> httpx.Response``. Builds and sends the next-page request. - client : requests.Session, optional - Caller-borrowed session. ``None`` (default) means use the - chunker's shared session (if inside a chunked call) or open + client : httpx.Client, optional + Caller-borrowed client. ``None`` (default) means use the + chunker's shared client (if inside a chunked call) or open a temporary one. Returns ------- df : pandas.DataFrame Concatenation of every page's parsed frame. - response : requests.Response + response : httpx.Response A shallow copy of the first-page response, with ``.headers`` - rebuilt as a fresh ``CaseInsensitiveDict`` reflecting the last + rebuilt as a fresh ``httpx.Headers`` reflecting the last page and ``.elapsed`` set to cumulative wall-clock. The canonical URL is preserved from the first page. The original first-page response is not mutated. @@ -906,22 +935,22 @@ def _paginate( (wrapped via :func:`_paginated_failure_message` with the original exception on ``__cause__``), or any failure on a subsequent page (same wrapping). - requests.exceptions.RequestException + httpx.HTTPError Network-level failures on the *initial* request (e.g. - ``ConnectionError``, ``Timeout``) propagate unmodified so - callers can branch on the specific type; equivalent failures - on subsequent pages are wrapped per above. + ``ConnectError``, ``TimeoutException``) propagate unmodified + so callers can branch on the specific type; equivalent + failures on subsequent pages are wrapped per above. """ logger.debug("Requesting: %s", initial_req.url) reporter = _progress.current() - with _session(client) as sess: - resp = sess.send(initial_req) + with _client_for(client) as client: + resp = client.send(initial_req) _raise_for_non_200(resp) # Keep the original-request response as the "canonical" one for # ``md.url`` reproducibility; ``.headers`` and ``.elapsed`` get # overwritten with latest/cumulative values below. initial_response = resp - total_elapsed = resp.elapsed + total_elapsed = _safe_elapsed(resp) try: df, cursor = parse_response(resp) @@ -941,11 +970,11 @@ def _paginate( reporter.add_page(rows=len(df)) while cursor is not None: try: - resp = follow_up(cursor, sess) + resp = follow_up(cursor, client) _raise_for_non_200(resp) df, cursor = parse_response(resp) dfs.append(df) - total_elapsed += resp.elapsed + total_elapsed += _safe_elapsed(resp) if reporter is not None: reporter.set_rate_remaining( resp.headers.get(_QUOTA_HEADER), @@ -969,11 +998,27 @@ def _paginate( return pd.concat(dfs, ignore_index=True), final_response +def _ogc_parse_response( + resp: httpx.Response, *, geopd: bool +) -> tuple[pd.DataFrame, str | None]: + """Parse one OGC API page: extract the DataFrame and the next-page URL. + + Coerces falsy cursors (empty href, etc.) to ``None`` so the + paginate loop's ``while cursor is not None`` terminates instead + of spinning on a meaningless value. + """ + body = resp.json() + return ( + _get_resp_data(resp, geopd=geopd, body=body), + _next_req_url(resp, body=body) or None, + ) + + def _walk_pages( geopd: bool, - req: requests.PreparedRequest, - client: requests.Session | None = None, -) -> tuple[pd.DataFrame, requests.Response]: + req: httpx.Request, + client: httpx.Client | None = None, +) -> tuple[pd.DataFrame, httpx.Response]: """ Iterate through paginated OGC API responses and aggregate into one DataFrame. @@ -987,17 +1032,17 @@ def _walk_pages( ---------- geopd : bool Whether geopandas is installed (drives geometry handling). - req : requests.PreparedRequest + req : httpx.Request The initial HTTP request to send. - client : requests.Session, optional - Caller-borrowed session; ``None`` defers session management to + client : httpx.Client, optional + Caller-borrowed client; ``None`` defers client management to :func:`_paginate`. Returns ------- pd.DataFrame A DataFrame containing the aggregated results from all pages. - requests.Response + httpx.Response Aggregated response — initial-request URL (for query identity), final page's headers (so downstream sees current rate-limit state), and cumulative ``elapsed`` summed across pages. @@ -1006,29 +1051,19 @@ def _walk_pages( ------ RuntimeError See :func:`_paginate`. - requests.exceptions.RequestException + httpx.HTTPError See :func:`_paginate`. """ - method = req.method # ``PreparedRequest.method`` is already upper-cased. - headers = dict(req.headers) - content = req.body if method == "POST" else None + method = req.method # ``httpx.Request.method`` is already upper-cased. + headers = req.headers + content = req.content if method == "POST" else None - def parse_response(resp: requests.Response) -> tuple[pd.DataFrame, str | None]: - body = resp.json() - # Coerce falsy cursors (empty href, etc.) to None so - # _paginate's `while cursor is not None` terminates instead of - # spinning on a meaningless value. - return ( - _get_resp_data(resp, geopd=geopd, body=body), - _next_req_url(resp, body=body) or None, - ) - - def follow_up(cursor: str, sess: requests.Session) -> requests.Response: - return sess.request(method, cursor, headers=headers, data=content) + def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: + return client.request(method, cursor, headers=headers, content=content) return _paginate( req, - parse_response=parse_response, + parse_response=functools.partial(_ogc_parse_response, geopd=geopd), follow_up=follow_up, client=client, ) @@ -1255,18 +1290,20 @@ def get_ogc_data( return return_list, BaseMetadata(response) -@chunking.multi_value_chunked(build_request=_construct_api_requests) +@chunking.multi_value_chunked( + build_request=_construct_api_requests, +) def _fetch_once( args: dict[str, Any], -) -> tuple[pd.DataFrame, requests.Response]: +) -> tuple[pd.DataFrame, httpx.Response]: """Send one prepared-args OGC request; return the frame + response. ``@chunking.multi_value_chunked`` models every multi-value list parameter and the cql-text filter as a chunkable axis, greedy-halves the biggest chunk across all axes until each sub-request URL fits, and iterates the cartesian product. With no chunkable inputs the - decorator passes args through unchanged. Either way the return - shape is ``(frame, response)``. + decorator passes args through unchanged. The return shape + is ``(frame, response)``. """ req = _construct_api_requests(**args) return _walk_pages(geopd=GEOPANDAS, req=req) @@ -1432,7 +1469,7 @@ def get_stats_data( args: dict[str, Any], service: str, expand_percentiles: bool, - client: requests.Session | None = None, + client: httpx.Client | None = None, ) -> tuple[pd.DataFrame, BaseMetadata]: """ Retrieves statistical data from a specified endpoint and returns it @@ -1464,27 +1501,26 @@ def get_stats_data( """ url = f"{STATISTICS_API_URL}/{service}" - request = requests.Request( + req = httpx.Request( method="GET", url=url, headers=_default_headers(), params=args, ) - req = request.prepare() - method = req.method # ``PreparedRequest.method`` is already upper-cased. - headers = dict(req.headers) + method = req.method + headers = req.headers - def parse_response(resp: requests.Response) -> tuple[pd.DataFrame, str | None]: + def parse_response(resp: httpx.Response) -> tuple[pd.DataFrame, str | None]: body = resp.json() # Coerce falsy cursors ("", 0) to None so _paginate terminates. # USGS uses "next": null at end-of-stream, but defensive coerce # protects against any "" sentinel a future schema might use. return _handle_stats_nesting(body, geopd=GEOPANDAS), body.get("next") or None - def follow_up(cursor: str, sess: requests.Session) -> requests.Response: - # Build a fresh params dict per page so the caller's ``args`` is - # never mutated. - return sess.request( + def follow_up(cursor: str, client: httpx.Client) -> httpx.Response: + # Build a fresh params dict per page so the caller's ``args`` + # is never mutated. + return client.request( method, url=url, params={**args, "next_token": cursor}, headers=headers ) diff --git a/dataretrieval/wqp.py b/dataretrieval/wqp.py index 8cfc6ca1..e874f0be 100644 --- a/dataretrieval/wqp.py +++ b/dataretrieval/wqp.py @@ -625,7 +625,7 @@ class WQP_Metadata(BaseMetadata): Response url query_time : datetme.timedelta Response elapsed time - header : requests.structures.CaseInsensitiveDict + header : httpx.Headers Response headers comments : None Metadata comments. WQP does not return comments. @@ -640,7 +640,7 @@ def __init__(self, response, **parameters) -> None: Parameters ---------- response : Response - Response object from requests module + Response object from httpx module parameters : dict Unpacked dictionary of the parameters supplied in the request diff --git a/pyproject.toml b/pyproject.toml index 5c9fbda0..65b1ae68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers = [ "Programming Language :: Python :: 3", ] dependencies = [ - "requests", + "httpx", "pandas>=2.0.0,<4.0.0", ] dynamic = ["version"] @@ -37,7 +37,7 @@ test = [ "pytest-cov[all]", "pytest-rerunfailures", "coverage", - "requests-mock", + "pytest-httpx", "ruff", ] doc = [ diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..afbdfec2 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,32 @@ +""" +Test scaffolding for the dataretrieval test suite. + +Relaxes ``pytest-httpx``'s strict-mode flags so unconsumed mocks and +unmatched real requests don't fail the suite (matches the historical +``requests-mock``-style permissiveness the test code was written +against, and keeps mocked-URL setup terse). +""" + +from __future__ import annotations + +import pytest + + +def pytest_collection_modifyitems(config, items): + """Apply relaxed ``pytest-httpx`` strict-mode settings to every + test in the suite — matches the permissive defaults the historical + tests were written against.""" + marker = pytest.mark.httpx_mock( + assert_all_responses_were_requested=False, + assert_all_requests_were_expected=False, + can_send_already_matched_responses=True, + ) + for item in items: + item.add_marker(marker) + + +@pytest.fixture +def non_mocked_hosts() -> list[str]: + """No hosts are exempted from mocking; every HTTP call must hit + a mock registered through the ``httpx_mock`` fixture.""" + return [] diff --git a/tests/nldi_test.py b/tests/nldi_test.py index 91be5026..988d9672 100644 --- a/tests/nldi_test.py +++ b/tests/nldi_test.py @@ -18,7 +18,7 @@ def _reset_data_source_cache(monkeypatch): monkeypatch.setattr(nldi, "_AVAILABLE_DATA_SOURCES", None) -def mock_request_data_sources(requests_mock): +def mock_request_data_sources(httpx_mock): request_url = f"{NLDI_API_BASE_URL}/" available_data_sources = [ {"source": "ca_gages"}, @@ -38,42 +38,48 @@ def mock_request_data_sources(requests_mock): {"source": "WQP"}, {"source": "comid"}, ] - requests_mock.get( - request_url, json=available_data_sources, headers={"mock_header": "value"} + httpx_mock.add_response( + method="GET", + url=request_url, + json=available_data_sources, + headers={"mock_header": "value"}, ) -def mock_request(requests_mock, request_url, file_path): +def mock_request(httpx_mock, request_url, file_path): with open(file_path) as text: - requests_mock.get( - request_url, text=text.read(), headers={"mock_header": "value"} + httpx_mock.add_response( + method="GET", + url=request_url, + text=text.read(), + headers={"mock_header": "value"}, ) -def test_get_basin(requests_mock): +def test_get_basin(httpx_mock): """Tests NLDI get basin query""" request_url = ( f"{NLDI_API_BASE_URL}/WQP/USGS-054279485/basin" f"?simplified=true&splitCatchment=false" ) response_file_path = "tests/data/nldi_get_basin.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) gdf = get_basin(feature_source="WQP", feature_id="USGS-054279485") assert isinstance(gdf, GeoDataFrame) assert gdf.size == 1 -def test_get_flowlines(requests_mock): +def test_get_flowlines(httpx_mock): """Tests NLDI get flowlines query using feature source as the origin""" request_url = ( f"{NLDI_API_BASE_URL}/WQP/USGS-054279485/navigation/UM/flowlines" f"?distance=5&trimStart=false" ) response_file_path = "tests/data/nldi_get_flowlines.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) gdf = get_flowlines( feature_source="WQP", feature_id="USGS-054279485", navigation_mode="UM" @@ -82,21 +88,22 @@ def test_get_flowlines(requests_mock): assert gdf.size == 2 -def test_get_flowlines_by_comid(requests_mock): +def test_get_flowlines_by_comid(httpx_mock): """Tests NLDI get flowlines query using comid as the origin""" request_url = ( - f"{NLDI_API_BASE_URL}/comid/13294314/navigation/UM/flowlines?distance=50" + f"{NLDI_API_BASE_URL}/comid/13294314/navigation/UM/flowlines" + "?distance=50&trimStart=false" ) response_file_path = "tests/data/nldi_get_flowlines_by_comid.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) gdf = get_flowlines(navigation_mode="UM", comid=13294314, distance=50) assert isinstance(gdf, GeoDataFrame) assert gdf.size == 16 -def test_features_by_feature_source_with_navigation(requests_mock): +def test_features_by_feature_source_with_navigation(httpx_mock): """Tests NLDI get features query using feature source as the origin with navigation mode """ @@ -106,8 +113,8 @@ def test_features_by_feature_source_with_navigation(requests_mock): response_file_path = ( "tests/data/nldi_get_features_by_feature_source_with_nav_mode.json" ) - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) gdf = get_features( feature_source="WQP", @@ -120,7 +127,7 @@ def test_features_by_feature_source_with_navigation(requests_mock): assert gdf.size == 108 -def test_features_by_feature_source_without_navigation(requests_mock): +def test_features_by_feature_source_without_navigation(httpx_mock): """Tests NLDI get features query using feature source as the origin without navigation mode """ @@ -128,20 +135,20 @@ def test_features_by_feature_source_without_navigation(requests_mock): response_file_path = ( "tests/data/nldi_get_features_by_feature_source_without_nav_mode.json" ) - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) gdf = get_features(feature_source="WQP", feature_id="USGS-054279485") assert isinstance(gdf, GeoDataFrame) assert gdf.size == 10 -def test_get_features_by_comid(requests_mock): +def test_get_features_by_comid(httpx_mock): """Tests NLDI get features query using comid as the origin""" request_url = f"{NLDI_API_BASE_URL}/comid/13294314/navigation/UM/WQP?distance=5" response_file_path = "tests/data/nldi_get_features_by_comid.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) gdf = get_features( comid=13294314, data_source="WQP", navigation_mode="UM", distance=5 @@ -150,26 +157,29 @@ def test_get_features_by_comid(requests_mock): assert gdf.size == 405 -def test_get_features_by_lat_long(requests_mock): +def test_get_features_by_lat_long(httpx_mock): """Tests NLDI get features query using lat/long as the origin""" request_url = ( f"{NLDI_API_BASE_URL}/comid/position?coords=POINT%28-89.509%2043.087%29" ) response_file_path = "tests/data/nldi_get_features_by_lat_long.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) gdf = get_features(lat=43.087, long=-89.509) assert isinstance(gdf, GeoDataFrame) assert gdf.size == 6 -def test_search_for_basin(requests_mock): +def test_search_for_basin(httpx_mock): """Tests NLDI search query for basin""" - request_url = f"{NLDI_API_BASE_URL}/WQP/USGS-054279485/basin" + request_url = ( + f"{NLDI_API_BASE_URL}/WQP/USGS-054279485/basin" + "?simplified=true&splitCatchment=false" + ) response_file_path = "tests/data/nldi_get_basin.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) search_results = search( feature_source="WQP", feature_id="USGS-054279485", find="basin" @@ -180,12 +190,15 @@ def test_search_for_basin(requests_mock): assert len(search_results["features"][0]["geometry"]["coordinates"][0]) == 122 -def test_search_for_flowlines(requests_mock): +def test_search_for_flowlines(httpx_mock): """Tests NLDI search query for flowlines""" - request_url = f"{NLDI_API_BASE_URL}/WQP/USGS-054279485/navigation/UM/flowlines" + request_url = ( + f"{NLDI_API_BASE_URL}/WQP/USGS-054279485/navigation/UM/flowlines" + "?distance=50&trimStart=false" + ) response_file_path = "tests/data/nldi_get_flowlines.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) search_results = search( feature_source="WQP", @@ -199,12 +212,15 @@ def test_search_for_flowlines(requests_mock): assert len(search_results["features"][0]["geometry"]["coordinates"]) == 27 -def test_search_for_flowlines_by_comid(requests_mock): +def test_search_for_flowlines_by_comid(httpx_mock): """Tests NLDI search query for flowlines by comid""" - request_url = f"{NLDI_API_BASE_URL}/comid/13294314/navigation/UM/flowlines" + request_url = ( + f"{NLDI_API_BASE_URL}/comid/13294314/navigation/UM/flowlines" + "?distance=50&trimStart=false" + ) response_file_path = "tests/data/nldi_get_flowlines_by_comid.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) search_results = search(comid=13294314, navigation_mode="UM", find="flowlines") assert isinstance(search_results, dict) @@ -213,7 +229,7 @@ def test_search_for_flowlines_by_comid(requests_mock): assert len(search_results["features"][0]["geometry"]["coordinates"]) == 27 -def test_search_for_features_by_feature_source_with_navigation(requests_mock): +def test_search_for_features_by_feature_source_with_navigation(httpx_mock): """Tests NLDI search query for features by feature source""" request_url = ( f"{NLDI_API_BASE_URL}/WQP/USGS-054279485/navigation/UM/nwissite?distance=50" @@ -221,8 +237,8 @@ def test_search_for_features_by_feature_source_with_navigation(requests_mock): response_file_path = ( "tests/data/nldi_get_features_by_feature_source_with_nav_mode.json" ) - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) search_results = search( feature_source="WQP", @@ -237,14 +253,14 @@ def test_search_for_features_by_feature_source_with_navigation(requests_mock): assert len(search_results["features"]) == 9 -def test_search_for_features_by_feature_source_without_navigation(requests_mock): +def test_search_for_features_by_feature_source_without_navigation(httpx_mock): """Tests NLDI search query for features by feature source""" request_url = f"{NLDI_API_BASE_URL}/WQP/USGS-054279485" response_file_path = ( "tests/data/nldi_get_features_by_feature_source_without_nav_mode.json" ) - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) search_results = search( feature_source="WQP", feature_id="USGS-054279485", find="features" @@ -255,12 +271,12 @@ def test_search_for_features_by_feature_source_without_navigation(requests_mock) assert len(search_results["features"]) == 1 -def test_search_for_features_by_comid(requests_mock): +def test_search_for_features_by_comid(httpx_mock): """Tests NLDI search query for features by comid""" request_url = f"{NLDI_API_BASE_URL}/comid/13294314/navigation/UM/WQP?distance=5" response_file_path = "tests/data/nldi_get_features_by_comid.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) search_results = search( comid=13294314, @@ -275,14 +291,14 @@ def test_search_for_features_by_comid(requests_mock): assert len(search_results["features"]) == 45 -def test_search_for_features_by_lat_long(requests_mock): +def test_search_for_features_by_lat_long(httpx_mock): """Tests NLDI search query for features by lat/long""" request_url = ( f"{NLDI_API_BASE_URL}/comid/position?coords=POINT%28-89.509%2043.087%29" ) response_file_path = "tests/data/nldi_get_features_by_lat_long.json" - mock_request_data_sources(requests_mock) - mock_request(requests_mock, request_url, response_file_path) + mock_request_data_sources(httpx_mock) + mock_request(httpx_mock, request_url, response_file_path) search_results = search(lat=43.087, long=-89.509, find="features") assert isinstance(search_results, dict) @@ -291,13 +307,13 @@ def test_search_for_features_by_lat_long(requests_mock): assert len(search_results["features"][0]["geometry"]["coordinates"]) == 27 -def test_validate_data_source_rejects_invalid_after_cache_populated(requests_mock): +def test_validate_data_source_rejects_invalid_after_cache_populated(httpx_mock): """Once the cache is warm, invalid data sources must still raise ValueError. Regression: previously the validation check was nested inside the cache-population branch, so all calls after the first silently passed. """ - mock_request_data_sources(requests_mock) + mock_request_data_sources(httpx_mock) nldi._validate_data_source("WQP") @@ -323,3 +339,49 @@ def test_validate_navigation_mode_raises_value_error_for_invalid(): def test_validate_navigation_mode_normalizes_lowercase(): """Regression: lowercase values used to validate but be sent unchanged.""" assert _validate_navigation_mode("um") == "UM" + + +def test_query_nldi_non_200_surfaces_reason_phrase(httpx_mock): + """``_query_nldi`` must include the response's reason phrase in + the raised ``ValueError``. Pre-fix this crashed with + ``AttributeError: 'Response' object has no attribute 'reason'`` + because the migration to httpx renamed ``.reason`` → + ``.reason_phrase`` but missed this call site.""" + httpx_mock.add_response( + method="GET", + url=f"{NLDI_API_BASE_URL}/WQP/USGS-MISSING/basin" + "?simplified=true&splitCatchment=false", + status_code=429, + ) + mock_request_data_sources(httpx_mock) + with pytest.raises(ValueError, match="Error reason:"): + nldi.get_basin(feature_source="WQP", feature_id="USGS-MISSING") + + +def test_validate_data_source_rejects_malformed_catalog(httpx_mock, monkeypatch): + """``_validate_data_source`` should raise ``ValueError`` with an + informative message if the NLDI base URL returns a non-list shape + (or a list whose entries don't carry ``source`` keys), instead of + crashing with ``TypeError: string indices must be integers``.""" + monkeypatch.setattr(nldi, "_AVAILABLE_DATA_SOURCES", None) + httpx_mock.add_response( + method="GET", + url=f"{NLDI_API_BASE_URL}/", + json={"error": "upstream maintenance"}, + ) + with pytest.raises(ValueError, match="unexpected shape"): + nldi._validate_data_source("WQP") + + +def test_query_504_raises_value_error(httpx_mock): + """``utils.query`` must classify 504 Gateway Timeout as a 5xx + failure. Pre-fix: the membership check ``[500, 502, 503]`` missed + 504 and returned the response unchanged, leading downstream + callers (e.g. ``_query_nldi``) to silently swallow the failure as + an empty dict via JSONDecodeError.""" + from dataretrieval.utils import query + + url = "https://example.invalid/x" + httpx_mock.add_response(method="GET", url=f"{url}?a=1", status_code=504) + with pytest.raises(ValueError, match="Service Unavailable: 504"): + query(url, {"a": "1"}) diff --git a/tests/nwis_test.py b/tests/nwis_test.py index a3b23da6..f343f26e 100644 --- a/tests/nwis_test.py +++ b/tests/nwis_test.py @@ -1,5 +1,6 @@ import datetime import json +import re import warnings from pathlib import Path from unittest import mock @@ -37,7 +38,7 @@ def _load_mock_json(file_name): return json.load(f) -def _test_iv_service(requests_mock): +def _test_iv_service(httpx_mock): """Mocked test of instantaneous value service""" start = START_DATE end = END_DATE @@ -48,17 +49,17 @@ def _test_iv_service(requests_mock): mock_json = _load_mock_json("nwis_iv_mock.json") # Match the base URL and ensure query parameters are correct - requests_mock.get( - "https://waterservices.usgs.gov/nwis/iv", + httpx_mock.add_response( + method="GET", + url=re.compile(r"^https://waterservices\.usgs\.gov/nwis/iv(\?.*)?$"), json=mock_json, - complete_qs=False, ) return get_record(site, start, end, service=service) -def test_iv_service_answer(requests_mock): - df = _test_iv_service(requests_mock) +def test_iv_service_answer(httpx_mock): + df = _test_iv_service(httpx_mock) # check multiindex function assert df.index.names == [ SITENO_COL, @@ -152,23 +153,25 @@ def test_warn_message_includes_replacement(self, func_name, replacement_substrin assert replacement_substring in message assert _NWIS_REMOVAL_DATE in message - def test_get_iv_fires_deprecation_on_call(self, requests_mock): + def test_get_iv_fires_deprecation_on_call(self, httpx_mock): """End-to-end: a real call routes through _warn_deprecated.""" - requests_mock.get( - "https://waterservices.usgs.gov/nwis/iv", + httpx_mock.add_response( + method="GET", + url=re.compile(r"^https://waterservices\.usgs\.gov/nwis/iv(\?.*)?$"), json={"value": {"timeSeries": []}}, ) with pytest.warns(DeprecationWarning, match="get_iv.*waterdata.get_continuous"): get_iv(sites="01491000") - def test_nested_calls_emit_one_warning(self, requests_mock): + def test_nested_calls_emit_one_warning(self, httpx_mock): """get_record(service='iv') wraps get_iv -> query_waterservices. Without re-entrancy suppression the user would see 3 near-identical deprecation warnings for one call; pin the outermost-only contract. """ - requests_mock.get( - "https://waterservices.usgs.gov/nwis/iv", + httpx_mock.add_response( + method="GET", + url=re.compile(r"^https://waterservices\.usgs\.gov/nwis/iv(\?.*)?$"), json={"value": {"timeSeries": []}}, ) with warnings.catch_warnings(record=True) as caught: @@ -318,7 +321,7 @@ def test_expandedrdb_get_info(self): assert "count_nu" not in data.columns -def test_empty_timeseries(requests_mock): +def test_empty_timeseries(httpx_mock): """Test based on empty case from GitHub Issue #26.""" sites = "011277906" start = "2010-07-20" @@ -326,10 +329,10 @@ def test_empty_timeseries(requests_mock): mock_json = _load_mock_json("nwis_iv_empty_mock.json") # Match the base URL and ensure query parameters are correct - requests_mock.get( - "https://waterservices.usgs.gov/nwis/iv", + httpx_mock.add_response( + method="GET", + url=re.compile(r"^https://waterservices\.usgs\.gov/nwis/iv(\?.*)?$"), json=mock_json, - complete_qs=False, ) df = get_record(sites=sites, service="iv", start=start, end=end) diff --git a/tests/utils_test.py b/tests/utils_test.py index 2c350b2b..c25e1084 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -19,12 +19,12 @@ def test_url_too_long(self): or abruptly close the connection (ConnectionError). Both are valid responses to an excessively long URL. """ - import requests as req + import httpx # all sites in MD sites, _ = nwis.what_sites(stateCd="MD") # raise error by trying to query them all, so URL is way too long - with pytest.raises((ValueError, req.exceptions.ConnectionError)): + with pytest.raises((ValueError, httpx.ConnectError)): nwis.get_iv(sites=sites.site_no.values.tolist()) def test_header(self): diff --git a/tests/waterdata_chunking_test.py b/tests/waterdata_chunking_test.py index d9a54a7d..21b23757 100644 --- a/tests/waterdata_chunking_test.py +++ b/tests/waterdata_chunking_test.py @@ -35,30 +35,35 @@ ChunkPlan, QuotaExhausted, RateLimited, - RequestExceedsQuota, RequestTooLarge, ServiceInterrupted, ServiceUnavailable, - _chunked_session, + _chunked_client, _extract_axes, - _read_remaining, multi_value_chunked, ) from dataretrieval.waterdata.utils import _construct_api_requests class _FakeReq: - __slots__ = ("url", "body") + """Stand-in for ``httpx.Request`` whose ``_request_bytes`` shape + is ``len(str(url)) + len(content)``.""" - def __init__(self, url, body=None): + __slots__ = ("url", "content") + + def __init__(self, url, content=b""): self.url = url - self.body = body + self.content = ( + content + if isinstance(content, (bytes, bytearray)) + else (content.encode("utf-8") if isinstance(content, str) else b"") + ) def _fake_build(*, base=200, **kwargs): """Fake build_request: URL length deterministic in its inputs. - Mirrors the GET-routed shape: payload goes in the URL, body is None. + Mirrors the GET-routed shape: payload goes in the URL, body is empty. List/string values are URL-encoded via ``quote_plus`` so the fake's byte count matches what the real ``_construct_api_requests`` would produce; otherwise an alphanumeric test could pass against the fake @@ -234,33 +239,6 @@ def fetch(args): assert calls[0]["monitoring_location_id"] == ["A", "B"] -def test_multi_value_chunked_emits_cartesian_product(): - """Two chunkable axes, each split into 2 chunks → exactly 4 sub-requests, - each pairing one chunk from each axis.""" - calls = [] - - @multi_value_chunked(build_request=_fake_build, url_limit=240) - def fetch(args): - calls.append({k: v for k, v in args.items() if k in ("sites", "pcodes")}) - return pd.DataFrame(), mock.Mock( - elapsed=datetime.timedelta(seconds=0.1), headers={} - ) - - fetch( - { - "sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10], - "pcodes": ["P1" * 10, "P2" * 10, "P3" * 10, "P4" * 10], - } - ) - # Both heavy → planner should split both axes. Confirm a cartesian shape: - # every unique site-chunk pairs with every unique pcode-chunk. - sites_seen = {tuple(c["sites"]) for c in calls} - pcodes_seen = {tuple(c["pcodes"]) for c in calls} - assert len(calls) == len(sites_seen) * len(pcodes_seen) - assert len(sites_seen) > 1 - assert len(pcodes_seen) > 1 - - def test_multi_value_chunked_emits_3d_cartesian_product(): """Three chunkable axes, each forced to split → exhaustive cartesian product across all three. Verifies the halving loop in @@ -332,20 +310,20 @@ def fetch(args): def test_chunked_session_shared_across_sub_requests(): """Every sub-request of one chunked call sees the same - ``requests.Session`` on the ``_chunked_session`` ContextVar, so + ``httpx.Client`` on the ``_chunked_client`` ContextVar, so downstream paginated helpers (``_walk_pages``) can reuse the connection pool instead of handshaking fresh on each sub-request.""" sessions_seen = [] @multi_value_chunked(build_request=_fake_build, url_limit=240) def fetch(args): - sessions_seen.append(_chunked_session.get()) + sessions_seen.append(_chunked_client.get()) return pd.DataFrame(), mock.Mock( elapsed=datetime.timedelta(seconds=0.1), headers={} ) # Outside a chunked call: no session published. - assert _chunked_session.get() is None + assert _chunked_client.get() is None fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) @@ -357,7 +335,7 @@ def fetch(args): # And it was the same object every time. assert len({id(s) for s in sessions_seen}) == 1 # On exit the ContextVar is reset to its default. - assert _chunked_session.get() is None + assert _chunked_client.get() is None def test_chunked_session_isolated_per_resume(): @@ -385,107 +363,22 @@ def fetch(args): fetch({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) # First resume's session is closed; ContextVar is reset. - assert _chunked_session.get() is None + assert _chunked_client.get() is None state["blow_up"] = False excinfo.value.call.resume() # Second resume's session is also cleaned up. - assert _chunked_session.get() is None + assert _chunked_client.get() is None def _quota_response(remaining: int | str | None) -> mock.Mock: - """A mock requests.Response-like object whose ``x-ratelimit-remaining`` + """A mock httpx.Response-like object whose ``x-ratelimit-remaining`` header reflects the given value (None → header absent).""" resp = mock.Mock(elapsed=datetime.timedelta(seconds=0.1)) resp.headers = {} if remaining is None else {_QUOTA_HEADER: str(remaining)} return resp -def test_read_remaining_parses_header(): - assert _read_remaining(_quota_response(42)) == 42 - - -def test_read_remaining_returns_none_when_header_missing(): - """No rate-limit header → ``None`` so ``ChunkedCall`` can branch - on ``is None`` instead of comparing against a magic sentinel.""" - assert _read_remaining(_quota_response(None)) is None - - -def test_read_remaining_returns_none_on_malformed_header(): - """Non-integer header value → ``None`` so a parse failure doesn't - trip the quota check.""" - assert _read_remaining(_quota_response("not-a-number")) is None - - -def test_request_exceeds_quota_after_first_chunk(): - """Plan totals 4 sub-requests. The first response reports - ``x-ratelimit-remaining=1`` — only 2 sub-requests fit total - (the one just issued + 1 more). The wrapper must raise - ``RequestExceedsQuota`` *before* issuing chunk 2, and the - exception must carry a ``.call`` handle so the first chunk's - already-fetched data is recoverable.""" - calls: list[dict] = [] - - def fetch(args): - calls.append(args) - return pd.DataFrame({"sites": list(args["sites"])}), _quota_response(1) - - decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - - with pytest.raises(RequestExceedsQuota) as excinfo: - decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - - err = excinfo.value - assert err.planned_chunks == 4 - assert err.available == 2 # remaining=1 + the chunk we just spent - assert err.deficit == 2 - assert len(calls) == 1, "only the first chunk should have been issued" - # The originating ChunkedCall is exposed on .call so the first - # chunk's already-fetched data is recoverable. - assert err.call is not None - assert err.call.completed_chunks == 1 - assert not err.call.partial_frame.empty - - -def test_request_exceeds_quota_message_reports_deficit(): - """The error must surface planned / available / deficit so callers - know precisely how far over budget the call is.""" - e = RequestExceedsQuota(planned_chunks=10, available=4, deficit=6) - msg = str(e) - assert "10" in msg - assert "4" in msg - assert "6" in msg - - -def test_request_exceeds_quota_not_raised_when_plan_fits(): - """If ``x-ratelimit-remaining`` is large enough to cover the rest - of the plan, ``ChunkedCall`` proceeds normally.""" - remaining_seq = iter([100, 99, 98, 97]) - - def fetch(args): - return ( - pd.DataFrame({"sites": list(args["sites"])}), - _quota_response(next(remaining_seq)), - ) - - decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - df, _ = decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - assert len(df) == 4 - - -def test_no_quota_check_when_header_absent(): - """Without an ``x-ratelimit-remaining`` header ``ChunkedCall`` - has no quota signal and must NOT synthesize a - ``RequestExceedsQuota``; every planned sub-request runs.""" - - def fetch(args): - return pd.DataFrame({"sites": list(args["sites"])}), _quota_response(None) - - decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - df, _ = decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - assert len(df) == 4 - - def test_quota_exhausted_on_mid_call_429(): """Mid-call 429 (a concurrent caller drained the window) surfaces as ``QuotaExhausted`` carrying the partial frame plus the chunk @@ -757,13 +650,13 @@ def test_chunk_interrupted_base_class_catches_both(): def test_connection_error_wrapped_as_service_interrupted(): - """A bare ``requests.exceptions.ConnectionError`` (or any other - transport-level RequestException) doesn't inherit from - ``RuntimeError``; without the widened catch in ``_issue`` it - would escape uncaught and the user would lose the resumable - handle to ``.call.resume()``. Verify ``ChunkedCall`` wraps it as - ``ServiceInterrupted`` so partial progress is preserved.""" - import requests as _requests + """A bare ``httpx.ConnectError`` (or any other transport-level + ``httpx.HTTPError``) doesn't inherit from ``RuntimeError``; + without the widened catch in ``_issue`` it would escape uncaught + and the user would lose the resumable handle to ``.call.resume()``. + Verify ``ChunkedCall`` wraps it as ``ServiceInterrupted`` so + partial progress is preserved.""" + import httpx as _httpx state = {"i": 0, "blow_up": True} @@ -771,7 +664,7 @@ def fetch(args): i = state["i"] state["i"] += 1 if i == 2 and state["blow_up"]: - raise _requests.exceptions.ConnectionError("connection reset") + raise _httpx.ConnectError("connection reset") return ( pd.DataFrame({"sites": list(args["sites"])}), _quota_response(500), @@ -785,13 +678,51 @@ def fetch(args): assert err.completed_chunks == 2 assert err.call is not None # The transport exception is on __cause__ so callers can drill in if needed. - assert isinstance(err.__cause__, _requests.exceptions.ConnectionError) + assert isinstance(err.__cause__, _httpx.ConnectError) # Resume after the upstream recovers. state["blow_up"] = False df, _ = err.call.resume() assert set(df["sites"]) == {"S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10} +def test_invalid_url_wrapped_as_service_interrupted(): + """``httpx.InvalidURL`` inherits from ``Exception``, NOT from + ``httpx.HTTPError``. Without the widened catch in ``_issue`` / + ``_classify_chunk_error`` an oversize follow-up URL escapes as + raw ``InvalidURL`` and the user loses ``.call.resume()`` access + to the partial state. Mirror the ConnectError test.""" + import httpx as _httpx + + state = {"i": 0, "blow_up": True} + + def fetch(args): + i = state["i"] + state["i"] += 1 + if i == 2 and state["blow_up"]: + raise _httpx.InvalidURL("URL is too long: 65536 bytes > 65000") + return ( + pd.DataFrame({"sites": list(args["sites"])}), + _quota_response(500), + ) + + decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) + with pytest.raises(ServiceInterrupted) as excinfo: + decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10, "S5" * 10]}) + + err = excinfo.value + assert err.completed_chunks == 2 + assert err.call is not None + assert isinstance(err.__cause__, _httpx.InvalidURL) + # The top-level message must surface the underlying cause text so + # the user doesn't have to traverse ``__cause__`` to know what + # actually failed (previously the message was generic "Service + # error after K/N sub-requests; ... resume() once the upstream + # recovers", with the real "URL too long" only visible via + # ``.__cause__``). + assert "InvalidURL" in str(err) + assert "URL is too long" in str(err) + + def test_service_interrupted_exposes_partial_frame_and_response(): """Both ``QuotaExhausted`` AND ``ServiceInterrupted`` carry ``partial_frame`` / ``partial_response`` directly on the @@ -899,9 +830,9 @@ def fetch(args): def test_combine_chunk_responses_returns_independent_headers(): """The aggregated response's ``.headers`` must be a fresh - ``CaseInsensitiveDict`` — mutations by downstream callers - (logging hooks, metadata extensions) must not back-propagate into - the underlying chunk response's headers, which still live on + ``httpx.Headers`` — mutations by downstream callers (logging + hooks, metadata extensions) must not back-propagate into the + underlying chunk response's headers, which still live on ``ChunkedCall._chunks``.""" from dataretrieval.waterdata.chunking import _combine_chunk_responses @@ -930,7 +861,7 @@ def test_paginate_terminates_on_empty_string_cursor(): import datetime as _dt from unittest import mock as _mock - import requests as _requests + import httpx as _httpx from dataretrieval.waterdata import utils as _utils @@ -942,20 +873,20 @@ def test_paginate_terminates_on_empty_string_cursor(): "features": [{"id": "1", "properties": {"val": "a"}}], "links": [{"rel": "next", "href": ""}], } - resp = _mock.MagicMock(spec=_requests.Response) + resp = _mock.MagicMock(spec=_httpx.Response) resp.status_code = 200 resp.url = "https://example.com/items?limit=1" resp.elapsed = _dt.timedelta(seconds=0.1) resp.headers = {} resp.json.return_value = body_with_empty_next - client = _mock.MagicMock(spec=_requests.Session) + client = _mock.MagicMock(spec=_httpx.Client) client.send.return_value = resp - req = _mock.MagicMock(spec=_requests.PreparedRequest) + req = _mock.MagicMock(spec=_httpx.Request) req.method = "GET" req.headers = {} - req.body = None + req.content = b"" req.url = "https://example.com/items?limit=1" df, final = _utils._walk_pages(geopd=False, req=req, client=client) @@ -1036,35 +967,71 @@ def test_quota_exhausted_message_points_at_resume(): assert ".call.resume()" in msg -def test_request_bytes_rejects_non_sizable_body(): - """``_request_bytes`` requires a deterministic byte count up front; - silently treating an unknown body as zero would under-chunk and let - the request blow past the server's POST-body limit. Generators, - iterables, and file-like objects must surface as ``TypeError``.""" +def test_request_bytes_sums_url_and_content(): + """``_request_bytes`` returns ``len(str(url)) + len(content)``. + + ``httpx.Request`` always carries ``.content`` as ``bytes`` (the + constructor normalises ``data``/``json``/``content`` inputs), so + the chunker just needs to size that single attribute alongside + the URL. + """ + import httpx + from dataretrieval.waterdata.chunking import _request_bytes - class _FakeReqWithGenBody: - url = "https://example.com/foo" - body = (b"x" for _ in range(3)) + # GET request with no body + req = httpx.Request("GET", "https://x.example/ab") + assert _request_bytes(req) == len("https://x.example/ab") - with pytest.raises(TypeError, match="cannot size a request body"): - _request_bytes(_FakeReqWithGenBody()) + # POST request with content + req = httpx.Request("POST", "https://x.example/ab", content=b"cd") + assert _request_bytes(req) == len("https://x.example/ab") + 2 -def test_request_bytes_handles_supported_body_types(): - """Sanity-check the supported body types: None (GET), bytes (raw - POST), str (JSON-as-string POST).""" - from dataretrieval.waterdata.chunking import _request_bytes +def test_safe_request_bytes_treats_invalid_url_as_overflow(): + """``httpx.URL`` enforces a 64 KB cap per URL component and raises + ``httpx.InvalidURL`` for anything bigger — e.g. comma-joining all + California stream sites in one query. The planner's halving loop + must keep shrinking past that cap rather than crashing; the + contract is that ``_safe_request_bytes`` returns ``url_limit + 1`` + (a value strictly greater than the limit) when ``build_request`` + raises ``InvalidURL``.""" + import httpx + + from dataretrieval.waterdata.chunking import _safe_request_bytes + + def build_request(**kwargs): + raise httpx.InvalidURL("URL too long") + + url_limit = 8000 + assert _safe_request_bytes(build_request, {}, url_limit) == url_limit + 1 + + +def test_chunk_plan_handles_initial_url_overflow(): + """A user query whose unchunked URL exceeds the 64 KB + ``httpx.URL`` cap (e.g. 5000+ site IDs comma-joined) must not + crash ``ChunkPlan.__init__``; the planner falls back to a + worst-case sub-request URL for ``canonical_url`` and proceeds to + halve the over-limit axes normally.""" + import httpx + + real_build = _fake_build - class _Req: - def __init__(self, url, body): - self.url = url - self.body = body + def overflowing_build(**kwargs): + # Mimic httpx: any single sub-arg whose ``sites`` list has + # more than 2 entries fails URL construction (proxy for a + # 64 KB overflow at the worst case). + if len(kwargs.get("sites", [])) > 2: + raise httpx.InvalidURL("URL > 64 KB") + return real_build(**kwargs) - assert _request_bytes(_Req("ab", None)) == 2 - assert _request_bytes(_Req("ab", b"cd")) == 4 - assert _request_bytes(_Req("ab", "cd")) == 4 - assert _request_bytes(_Req("ab", bytearray(b"cd"))) == 4 + sites = ["S" * 10 + str(i) for i in range(8)] + plan = ChunkPlan({"sites": sites}, overflowing_build, url_limit=8000) + # Planner kept halving until every worst-case sub-arg had ≤2 sites. + assert all(len(c) <= 2 for c in plan.chunks["sites"]) + assert plan.total > 1 + # canonical_url fell back to a constructable worst-case URL. + assert plan.canonical_url is not None def test_multi_value_chunked_restores_canonical_url(): @@ -1169,7 +1136,7 @@ def test_joint_planner_url_construction_long_filter_and_long_sites(): over_limit = [] for sub_args in plan.iter_sub_args(): req = _construct_api_requests(**sub_args) - url_len = len(req.url) + (len(req.body) if req.body else 0) + url_len = len(str(req.url)) + len(req.content) if url_len > url_limit: over_limit.append((url_len, sub_args)) assert not over_limit, ( @@ -1194,11 +1161,9 @@ def test_joint_planner_url_construction_long_filter_and_long_sites(): def test_combine_chunk_frames_all_empty_preserves_geo_type(): - """Regression: when every chunk returns an empty frame, - ``_combine_chunk_frames`` must not downgrade an empty - ``GeoDataFrame`` to a plain ``DataFrame``. The whole reason the - function drops empties before concat is to prevent that downgrade - — the all-empty short-circuit was independently dropping it.""" + """An all-empty chunk list preserves the ``GeoDataFrame`` type. + Dropping empties before concat exists precisely to prevent type + downgrade; the all-empty branch must honor the same contract.""" pytest.importorskip("geopandas") import geopandas as gpd @@ -1212,10 +1177,10 @@ def test_combine_chunk_frames_all_empty_preserves_geo_type(): def test_combine_chunk_frames_single_frame_is_safe_to_mutate(): - """Regression: the single-completed-chunk fast path returned the - underlying chunk frame verbatim, so a caller mutating - ``call.partial_frame`` (documented as a live view) would mutate - ``_chunks[0][0]`` in place. The fast path now returns a copy.""" + """``_combine_chunk_frames`` returns a frame independent of its + input on the single-chunk fast path — a caller mutating + ``call.partial_frame`` (a live view) must not back-propagate into + the underlying ``_chunks[0][0]`` frame.""" from dataretrieval.waterdata.chunking import _combine_chunk_frames chunk = pd.DataFrame({"id": ["A", "B"], "value": [1, 2]}) @@ -1225,10 +1190,9 @@ def test_combine_chunk_frames_single_frame_is_safe_to_mutate(): def test_iter_sub_args_passthrough_yields_a_copy(): - """Regression: the no-axes passthrough yielded ``self.args`` - directly while the chunked branch did ``dict(self.args)``. A - ``fetch_once`` that mutated the dict it received would silently - corrupt ``ChunkPlan.args``. The passthrough now copies too.""" + """``ChunkPlan.iter_sub_args`` yields a fresh dict on every path + (passthrough and chunked), so a ``fetch_once`` that mutates the + dict it receives cannot corrupt ``ChunkPlan.args``.""" args = {"monitoring_location_id": ["USGS-A"], "limit": 100} plan = ChunkPlan(args, _fake_build, url_limit=8000) sub = next(plan.iter_sub_args()) @@ -1238,34 +1202,31 @@ def test_iter_sub_args_passthrough_yields_a_copy(): assert "new_key" not in plan.args -def test_quota_check_fires_after_every_chunk_not_just_first(): - """Regression: ``_check_quota_after_first`` was gated on - ``len(_chunks) == 1`` so it only fired after chunk 0; a concurrent - caller draining the window mid-call (or a partially-rolled-over - quota on resume) went undetected. The check now fires after every - non-final chunk.""" - # 4-chunk plan. Chunks 0 and 1 report plenty of remaining quota; - # chunk 2's response reports remaining=0 with one chunk still - # pending. The check must fire after chunk 2, NOT silently let - # chunk 3 hit a mid-stream 429. - responses = iter([500, 500, 0]) - calls: list[dict] = [] +def test_combine_chunk_responses_does_not_mutate_input_urls(): + """Regression for the _set_response_url aliasing bug. - def fetch(args): - calls.append(args) - return pd.DataFrame({"sites": list(args["sites"])}), _quota_response( - next(responses) - ) + ``_combine_chunk_responses`` shallow-copies the first response; + if the canonical-URL override is applied by mutating the bound + ``request.url``, the shallow alias back-propagates the URL change + into the underlying chunk-0 response — breaking the documented + 'input responses are not mutated' invariant. The fix is to swap + in a fresh ``httpx.Request`` rather than mutate the existing one. + """ + import httpx as _httpx - decorated = multi_value_chunked(build_request=_fake_build, url_limit=240)(fetch) - with pytest.raises(RequestExceedsQuota) as excinfo: - decorated({"sites": ["S1" * 10, "S2" * 10, "S3" * 10, "S4" * 10]}) - err = excinfo.value - assert err.planned_chunks == 4 - # 3 completed + 0 remaining = 3 available; 1 pending; deficit 1. - assert err.available == 3 - assert err.deficit == 1 - assert len(calls) == 3, "only chunks 0-2 should have been issued" - # .call carries the in-flight call so the user can recover. - assert err.call is not None - assert err.call.completed_chunks == 3 + from dataretrieval.waterdata.chunking import _combine_chunk_responses + + req1 = _httpx.Request("GET", "https://example.com/chunk0") + req2 = _httpx.Request("GET", "https://example.com/chunk1") + r1 = _httpx.Response(200, request=req1) + r2 = _httpx.Response(200, request=req2) + + out = _combine_chunk_responses( + [r1, r2], canonical_url="https://canonical.example/full" + ) + assert str(out.url) == "https://canonical.example/full" + # The inputs and their bound requests must be untouched. + assert str(r1.url) == "https://example.com/chunk0" + assert str(r2.url) == "https://example.com/chunk1" + assert str(req1.url) == "https://example.com/chunk0" + assert str(req2.url) == "https://example.com/chunk1" diff --git a/tests/waterdata_filters_test.py b/tests/waterdata_filters_test.py index 9d9d183e..32879318 100644 --- a/tests/waterdata_filters_test.py +++ b/tests/waterdata_filters_test.py @@ -14,7 +14,7 @@ def _query_params(prepared_request): - return parse_qs(urlsplit(prepared_request.url).query) + return parse_qs(urlsplit(str(prepared_request.url)).query) def _fake_prepared_request(url="https://example.test"): diff --git a/tests/waterdata_progress_test.py b/tests/waterdata_progress_test.py index 14a98839..faa61630 100644 --- a/tests/waterdata_progress_test.py +++ b/tests/waterdata_progress_test.py @@ -11,8 +11,8 @@ import types from unittest import mock +import httpx import pytest -import requests from dataretrieval.waterdata import _progress from dataretrieval.waterdata._progress import ( @@ -305,11 +305,11 @@ def test_walk_pages_reports_pages_and_rate_limit(): ) resp2 = _resp([{"id": "2", "properties": {"v": "b"}}], rate_remaining="4998") - client = mock.MagicMock(spec=requests.Session) + client = mock.MagicMock(spec=httpx.Client) client.send.return_value = resp1 client.request.return_value = resp2 - req = mock.MagicMock(spec=requests.PreparedRequest) + req = mock.MagicMock(spec=httpx.Request) req.method = "GET" req.headers = {} req.url = "https://example.com/p1" @@ -330,10 +330,10 @@ def test_walk_pages_reports_pages_and_rate_limit(): def test_walk_pages_without_context_does_not_error(): # No active reporter: pagination must still work and stay silent. resp = _resp([{"id": "1", "properties": {"v": "a"}}]) - client = mock.MagicMock(spec=requests.Session) + client = mock.MagicMock(spec=httpx.Client) client.send.return_value = resp - req = mock.MagicMock(spec=requests.PreparedRequest) + req = mock.MagicMock(spec=httpx.Request) req.method = "GET" req.headers = {} req.url = "https://example.com/p1" @@ -350,11 +350,11 @@ def test_broken_progress_stream_does_not_truncate_pagination(): [{"id": "1", "properties": {"v": "a"}}], next_url="https://example.com/p2" ) resp2 = _resp([{"id": "2", "properties": {"v": "b"}}]) - client = mock.MagicMock(spec=requests.Session) + client = mock.MagicMock(spec=httpx.Client) client.send.return_value = resp1 client.request.return_value = resp2 - req = mock.MagicMock(spec=requests.PreparedRequest) + req = mock.MagicMock(spec=httpx.Request) req.method = "GET" req.headers = {} req.url = "https://example.com/p1" diff --git a/tests/waterdata_ratings_test.py b/tests/waterdata_ratings_test.py index fcead65d..9cbb4b70 100644 --- a/tests/waterdata_ratings_test.py +++ b/tests/waterdata_ratings_test.py @@ -1,3 +1,4 @@ +import re import sys from urllib.parse import parse_qs, urlsplit @@ -10,6 +11,15 @@ from dataretrieval.waterdata import get_ratings from dataretrieval.waterdata.ratings import _build_filter +# pytest-httpx matches URL strings exactly (including query). For the +# ratings tests we want a "match this endpoint, ignore the params" +# fixture so the assertions can drill into the captured params +# afterwards without coupling the registration to the implementation's +# parameter order. ``url=STAC_SEARCH_RE`` does that. +STAC_SEARCH_RE = re.compile( + r"^https://api\.waterdata\.usgs\.gov/stac/v0/search(\?.*)?$" +) + def test_build_filter_single_site_single_type(): f = _build_filter("USGS-01104475", "exsa") @@ -77,14 +87,16 @@ def _stub_search_response(): } -def test_get_ratings_mocked_search_and_download(requests_mock, tmp_path): +def test_get_ratings_mocked_search_and_download(httpx_mock, tmp_path): """End-to-end happy path with mocked STAC search + RDB download.""" - requests_mock.get( - "https://api.waterdata.usgs.gov/stac/v0/search", + httpx_mock.add_response( + method="GET", + url=STAC_SEARCH_RE, json=_stub_search_response(), ) - requests_mock.get( - "https://api.waterdata.usgs.gov/stac-files/ratings/USGS.01104475.exsa.rdb", + httpx_mock.add_response( + method="GET", + url="https://api.waterdata.usgs.gov/stac-files/ratings/USGS.01104475.exsa.rdb", text=_SAMPLE_RDB, ) @@ -100,22 +112,23 @@ def test_get_ratings_mocked_search_and_download(requests_mock, tmp_path): assert len(df) == 3 # Server-side filter should pin the single requested file_type. - sent = requests_mock.request_history[0] - qs = parse_qs(urlsplit(sent.url).query) + sent = httpx_mock.get_requests()[0] + qs = parse_qs(urlsplit(str(sent.url)).query) assert "file_type = 'exsa'" in qs["filter"][0] assert "monitoring_location_id IN ('USGS-01104475')" in qs["filter"][0] -def test_get_ratings_attaches_rdb_comment_and_url(requests_mock, tmp_path): +def test_get_ratings_attaches_rdb_comment_and_url(httpx_mock, tmp_path): """Each parsed frame should carry its RDB header + source URL in df.attrs.""" - requests_mock.get( - "https://api.waterdata.usgs.gov/stac/v0/search", + httpx_mock.add_response( + method="GET", + url=STAC_SEARCH_RE, json=_stub_search_response(), ) asset_url = ( "https://api.waterdata.usgs.gov/stac-files/ratings/USGS.01104475.exsa.rdb" ) - requests_mock.get(asset_url, text=_SAMPLE_RDB) + httpx_mock.add_response(method="GET", url=asset_url, text=_SAMPLE_RDB) out = get_ratings( monitoring_location_id="USGS-01104475", @@ -131,9 +144,10 @@ def test_get_ratings_attaches_rdb_comment_and_url(requests_mock, tmp_path): assert df.attrs["url"] == asset_url -def test_get_ratings_download_and_parse_false_returns_features(requests_mock): - requests_mock.get( - "https://api.waterdata.usgs.gov/stac/v0/search", +def test_get_ratings_download_and_parse_false_returns_features(httpx_mock): + httpx_mock.add_response( + method="GET", + url=STAC_SEARCH_RE, json=_stub_search_response(), ) features = get_ratings( @@ -144,10 +158,11 @@ def test_get_ratings_download_and_parse_false_returns_features(requests_mock): assert features[0]["id"] == "USGS-01104475.exsa.rdb" -def test_get_ratings_multi_type_filters_via_property(requests_mock, tmp_path): +def test_get_ratings_multi_type_filters_via_property(httpx_mock, tmp_path): """File_type list: server filter omits it; local filter reads the property.""" - requests_mock.get( - "https://api.waterdata.usgs.gov/stac/v0/search", + httpx_mock.add_response( + method="GET", + url=STAC_SEARCH_RE, json={ "features": [ { @@ -169,8 +184,12 @@ def test_get_ratings_multi_type_filters_via_property(requests_mock, tmp_path): }, ) # Only mock the two URLs we expect to be downloaded. - requests_mock.get("https://x.example/X.exsa.rdb", text=_SAMPLE_RDB) - requests_mock.get("https://x.example/X.corr.rdb", text=_SAMPLE_RDB) + httpx_mock.add_response( + method="GET", url="https://x.example/X.exsa.rdb", text=_SAMPLE_RDB + ) + httpx_mock.add_response( + method="GET", url="https://x.example/X.corr.rdb", text=_SAMPLE_RDB + ) out = get_ratings( monitoring_location_id="USGS-X", @@ -180,6 +199,6 @@ def test_get_ratings_multi_type_filters_via_property(requests_mock, tmp_path): assert set(out) == {"USGS-X.exsa.rdb", "USGS-X.corr.rdb"} # Server-side filter must NOT include file_type for multi-type requests. - search_req = requests_mock.request_history[0] - qs = parse_qs(urlsplit(search_req.url).query) + search_req = httpx_mock.get_requests()[0] + qs = parse_qs(urlsplit(str(search_req.url)).query) assert "file_type" not in qs["filter"][0] diff --git a/tests/waterdata_test.py b/tests/waterdata_test.py index 24eb6eff..09f66aa5 100644 --- a/tests/waterdata_test.py +++ b/tests/waterdata_test.py @@ -44,27 +44,30 @@ # try. The marker is attached to every test in the module, but the # patterns match only traces produced by real network round-trips # (``_raise_for_non_200`` output, ``requests`` exceptions), so tests -# using ``requests_mock`` or ``mock.patch`` are no-ops for the rerun. +# using ``httpx_mock`` or ``mock.patch`` are no-ops for the rerun. pytestmark = pytest.mark.flaky( reruns=2, reruns_delay=5, only_rerun=[ r"(?:RateLimited|RuntimeError):\s*(?:429|5\d\d):", # _raise_for_non_200 output - r"ConnectionError", + r"Connect(ion)?Error", # requests' ConnectionError + httpx' ConnectError r"ReadTimeout|ConnectTimeout|Timeout", ], ) -def mock_request(requests_mock, request_url, file_path): +def mock_request(httpx_mock, request_url, file_path): """Mock request code""" with open(file_path) as text: - requests_mock.get( - request_url, text=text.read(), headers={"mock_header": "value"} + httpx_mock.add_response( + method="GET", + url=request_url, + text=text.read(), + headers={"mock_header": "value"}, ) -def test_mock_get_samples(requests_mock): +def test_mock_get_samples(httpx_mock): """Tests USGS Samples query""" request_url = ( "https://api.waterdata.usgs.gov/samples-data/results/fullphyschem?" @@ -72,7 +75,7 @@ def test_mock_get_samples(requests_mock): "&activityStartDateUpper=2024-12-31&monitoringLocationIdentifier=USGS-05406500&mimeType=text%2Fcsv" ) response_file_path = "tests/data/samples_results.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_samples( service="results", profile="fullphyschem", @@ -86,19 +89,19 @@ def test_mock_get_samples(requests_mock): assert df.shape == (67, 187) assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None assert df["Activity_StartDateTime"].notna().any() -def test_mock_get_samples_summary(requests_mock): +def test_mock_get_samples_summary(httpx_mock): """Tests USGS Samples summary query""" request_url = ( "https://api.waterdata.usgs.gov/samples-data/summary/USGS-04183500" "?mimeType=text%2Fcsv" ) response_file_path = "tests/data/samples_summary.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_samples_summary(monitoringLocationIdentifier="USGS-04183500") assert type(df) is DataFrame expected_columns = { @@ -115,7 +118,7 @@ def test_mock_get_samples_summary(requests_mock): assert (df["monitoringLocationIdentifier"] == "USGS-04183500").all() assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None @@ -141,8 +144,8 @@ def test_construct_api_requests_multivalue_get(): parameter_code=["00060", "00065"], ) assert req.method == "GET" - assert "monitoring_location_id=USGS-05427718%2CUSGS-05427719" in req.url - assert "parameter_code=00060%2C00065" in req.url + assert "monitoring_location_id=USGS-05427718%2CUSGS-05427719" in str(req.url) + assert "parameter_code=00060%2C00065" in str(req.url) def test_construct_api_requests_monitoring_locations_post(): @@ -154,7 +157,7 @@ def test_construct_api_requests_monitoring_locations_post(): assert req.method == "POST" assert req.headers["Content-Type"] == "application/query-cql-json" - body = json.loads(req.body) + body = json.loads(req.content) # Top-level shape: AND over a list of per-param predicates. assert body["op"] == "and" assert isinstance(body["args"], list) and len(body["args"]) == 1 @@ -175,8 +178,8 @@ def test_construct_api_requests_single_value_stays_get(): parameter_code="00060", ) assert req.method == "GET" - assert "monitoring_location_id=USGS-05427718" in req.url - assert "%2C" not in req.url # no comma-encoded multi-value + assert "monitoring_location_id=USGS-05427718" in str(req.url) + assert "%2C" not in str(req.url) # no comma-encoded multi-value def test_construct_api_requests_numeric_list_joins_with_str(): @@ -189,7 +192,7 @@ def test_construct_api_requests_numeric_list_joins_with_str(): water_year=[2020, 2021], ) assert req.method == "GET" - assert "water_year=2020%2C2021" in req.url + assert "water_year=2020%2C2021" in str(req.url) def test_construct_api_requests_two_element_date_list_becomes_interval(): @@ -204,7 +207,7 @@ def test_construct_api_requests_two_element_date_list_becomes_interval(): ) assert req.method == "GET" # `/` URL-encodes to %2F. Confirms _format_api_dates ran before the join. - assert "time=2024-01-01%2F2024-01-31" in req.url + assert "time=2024-01-01%2F2024-01-31" in str(req.url) def test_samples_results(): diff --git a/tests/waterdata_utils_test.py b/tests/waterdata_utils_test.py index c135115c..bb5ece10 100644 --- a/tests/waterdata_utils_test.py +++ b/tests/waterdata_utils_test.py @@ -2,9 +2,9 @@ import logging from unittest import mock +import httpx import pandas as pd import pytest -import requests import dataretrieval.waterdata.utils as _utils_module from dataretrieval.waterdata.chunking import RateLimited, ServiceUnavailable @@ -74,13 +74,13 @@ def test_walk_pages_multiple_mocked(): resp2.status_code = 200 # Mock client (Session) - mock_client = mock.MagicMock(spec=requests.Session) + mock_client = mock.MagicMock(spec=httpx.Client) # First call to send() returns resp1, then call to request() in loop returns resp2 mock_client.send.return_value = resp1 mock_client.request.return_value = resp2 # Mock request (PreparedRequest) - mock_req = mock.MagicMock(spec=requests.PreparedRequest) + mock_req = mock.MagicMock(spec=httpx.Request) mock_req.method = "GET" mock_req.headers = {} mock_req.url = "https://example.com/page1" @@ -115,14 +115,14 @@ def _walk_pages_with_failure(failure_resp_or_exc): """Run _walk_pages where page 1 succeeds and page 2 fails as given.""" resp1 = _resp_ok([{"id": "1", "properties": {"val": "a"}}]) - mock_client = mock.MagicMock(spec=requests.Session) + mock_client = mock.MagicMock(spec=httpx.Client) mock_client.send.return_value = resp1 if isinstance(failure_resp_or_exc, BaseException): mock_client.request.side_effect = failure_resp_or_exc else: mock_client.request.return_value = failure_resp_or_exc - mock_req = mock.MagicMock(spec=requests.PreparedRequest) + mock_req = mock.MagicMock(spec=httpx.Request) mock_req.method = "GET" mock_req.headers = {} mock_req.url = "https://example.com/page1" @@ -135,21 +135,21 @@ def test_walk_pages_raises_on_connection_error_mid_pagination(): chained, and the wrapper message must include recovery guidance that is NOT rate-limit-specific (no quota window involved).""" with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: - _walk_pages_with_failure(requests.ConnectionError("boom")) + _walk_pages_with_failure(httpx.ConnectError("boom")) msg = str(excinfo.value) - assert isinstance(excinfo.value.__cause__, requests.ConnectionError) + assert isinstance(excinfo.value.__cause__, httpx.ConnectError) assert "boom" in msg assert "retry the request" in msg assert "rate-limit window" not in msg def test_walk_pages_raises_with_class_name_when_cause_stringifies_empty(): - """Some ``requests`` exceptions (e.g. ``Timeout()`` with no args) + """Some ``httpx`` exceptions (e.g. ``TimeoutException("")``) stringify to ``""``. The wrapper must still produce an informative message — fall back to the exception class name.""" with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: - _walk_pages_with_failure(requests.Timeout()) + _walk_pages_with_failure(httpx.TimeoutException("")) msg = str(excinfo.value) assert "Timeout" in msg, msg @@ -206,10 +206,10 @@ def test_walk_pages_wraps_initial_page_parse_error(): # Body is unparseable JSON (gateway HTML page, truncated stream). resp.json.side_effect = json.JSONDecodeError("Expecting value", "...", 0) - mock_client = mock.MagicMock(spec=requests.Session) + mock_client = mock.MagicMock(spec=httpx.Client) mock_client.send.return_value = resp - mock_req = mock.MagicMock(spec=requests.PreparedRequest) + mock_req = mock.MagicMock(spec=httpx.Request) mock_req.method = "GET" mock_req.headers = {} mock_req.url = "https://example.com/page1" @@ -270,11 +270,11 @@ def test_walk_pages_does_not_mutate_initial_response(): "links": [], } - mock_client = mock.MagicMock(spec=requests.Session) + mock_client = mock.MagicMock(spec=httpx.Client) mock_client.send.return_value = page1 mock_client.request.return_value = page2 - mock_req = mock.MagicMock(spec=requests.PreparedRequest) + mock_req = mock.MagicMock(spec=httpx.Request) mock_req.method = "GET" mock_req.headers = {} mock_req.url = "https://example.com/page1" @@ -324,7 +324,7 @@ def _run_get_stats_data_with_failure(failure_resp_or_exc, monkeypatch): mock.MagicMock(return_value=pd.DataFrame()), ) - mock_client = mock.MagicMock(spec=requests.Session) + mock_client = mock.MagicMock(spec=httpx.Client) mock_client.send.return_value = _stats_initial_ok() if isinstance(failure_resp_or_exc, BaseException): mock_client.request.side_effect = failure_resp_or_exc @@ -347,11 +347,11 @@ def test_get_stats_data_raises_on_mid_pagination_failure(monkeypatch): follow-up callback is wired into ``_paginate`` correctly.""" with pytest.raises(RuntimeError, match="Paginated request failed") as excinfo: _run_get_stats_data_with_failure( - requests.ConnectionError("stats-boom"), + httpx.ConnectError("stats-boom"), monkeypatch, ) - assert isinstance(excinfo.value.__cause__, requests.ConnectionError) + assert isinstance(excinfo.value.__cause__, httpx.ConnectError) assert "stats-boom" in str(excinfo.value) @@ -605,12 +605,16 @@ def test_format_api_dates_rejects_mapping(): def _make_response(status, body, reason=None, content_type="text/html"): - resp = requests.Response() - resp.status_code = status - resp.reason = reason - resp._content = body.encode("utf-8") - resp.headers["Content-Type"] = content_type - return resp + headers = {"Content-Type": content_type} + extensions = {} + if reason is not None: + extensions["reason_phrase"] = reason.encode("utf-8") + return httpx.Response( + status_code=status, + content=body.encode("utf-8"), + headers=headers, + extensions=extensions, + ) def test_error_body_handles_non_json_html_response(): @@ -730,3 +734,41 @@ def test_raise_for_non_200_still_raises_bare_runtimeerror_for_other_4xx(): # ServiceUnavailable. Both subclass RuntimeError, so a plain # ``pytest.raises(RuntimeError)`` would match either. assert type(excinfo.value) is RuntimeError + + +def test_next_req_url_rejects_cross_host(): + """``_next_req_url`` must refuse to follow a next-page link to a + different host. The original request's headers (including any + auth-like artifacts) were minted for the original host; following + a server-supplied cross-host URL would leak them — and the URL + itself could be sensitive.""" + from dataretrieval.waterdata.utils import _next_req_url + + resp = mock.MagicMock() + resp.url = httpx.URL("https://api.waterdata.usgs.gov/page1") + body = { + "numberReturned": 1, + "features": [{"id": "1"}], + "links": [{"rel": "next", "href": "https://evil.example.org/secret"}], + } + with pytest.raises(RuntimeError, match="cross-host next-page"): + _next_req_url(resp, body=body) + + +def test_check_ogc_requests_raises_typed_on_5xx(httpx_mock): + """``_check_ogc_requests`` previously called ``resp.raise_for_status()``, + which leaks raw ``httpx.HTTPStatusError``. Now routes through + ``_raise_for_non_200`` so callers see ``ServiceUnavailable`` / + ``RateLimited`` / ``RuntimeError`` — the same typed contract as + the main data path.""" + from dataretrieval.waterdata.chunking import ServiceUnavailable + from dataretrieval.waterdata.utils import OGC_API_URL, _check_ogc_requests + + httpx_mock.add_response( + method="GET", + url=f"{OGC_API_URL}/collections/daily/schema", + status_code=503, + json={"code": "ServiceUnavailable", "description": "maintenance window"}, + ) + with pytest.raises(ServiceUnavailable): + _check_ogc_requests(endpoint="daily", req_type="schema") diff --git a/tests/waterservices_test.py b/tests/waterservices_test.py index 2126b4d1..874c2a0e 100644 --- a/tests/waterservices_test.py +++ b/tests/waterservices_test.py @@ -57,18 +57,20 @@ def test_query_waterservices_validation(): assert str(type_error.value) == "Service not recognized" -def test_query_validation(requests_mock): +def test_query_validation(httpx_mock): request_url = ( "https://waterservices.usgs.gov/nwis/stat?sites=bad_site_id&format=rdb" ) - requests_mock.get(request_url, status_code=400) + httpx_mock.add_response(method="GET", url=request_url, status_code=400) with pytest.raises(ValueError) as type_error: get_stats(sites="bad_site_id") assert request_url in str(type_error) request_url = "https://waterservices.usgs.gov/nwis/stat?sites=123456&format=rdb" - requests_mock.get( - request_url, text="No sites/data found using the selection criteria specified" + httpx_mock.add_response( + method="GET", + url=request_url, + text="No sites/data found using the selection criteria specified", ) with pytest.raises(NoSitesError) as no_sites_error: get_stats(sites="123456") @@ -82,7 +84,7 @@ def test_get_record_validation(): assert str(type_error.value) == "Unrecognized service: not_a_service" -def test_get_dv(requests_mock): +def test_get_dv(httpx_mock): """Verify get_dv builds the expected request URL and returns a DataFrame.""" format = "json" site = "01491000%2C01645000" @@ -91,7 +93,7 @@ def test_get_dv(requests_mock): f"&startDT=2020-02-14&endDT=2020-02-15&sites={site}" ) response_file_path = "tests/data/waterservices_dv.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_dv( sites=["01491000", "01645000"], start="2020-02-14", end="2020-02-15" ) @@ -100,11 +102,11 @@ def test_get_dv(requests_mock): raise TypeError(f"{type(df)} is not DataFrame base class type") assert df.size == 8 - assert_metadata(requests_mock, request_url, md, site, None, format) + assert_metadata(httpx_mock, request_url, md, site, None, format) @pytest.mark.parametrize("site_input_type_list", [True, False]) -def test_get_dv_site_value_types(requests_mock, site_input_type_list): +def test_get_dv_site_value_types(httpx_mock, site_input_type_list): """Tests get_dv method for valid input types for the 'sites' parameter""" _format = "json" site = "01491000" @@ -113,7 +115,7 @@ def test_get_dv_site_value_types(requests_mock, site_input_type_list): f"&startDT=2020-02-14&endDT=2020-02-15&sites={site}" ) response_file_path = "tests/data/waterservices_dv.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) if site_input_type_list: sites = [site] else: @@ -125,7 +127,7 @@ def test_get_dv_site_value_types(requests_mock, site_input_type_list): assert df.size == 8 -def test_get_iv(requests_mock): +def test_get_iv(httpx_mock): """Verify get_iv builds the expected request URL and returns a DataFrame.""" format = "json" site = "01491000%2C01645000" @@ -134,7 +136,7 @@ def test_get_iv(requests_mock): f"&startDT=2019-02-14&endDT=2020-02-15&sites={site}" ) response_file_path = "tests/data/waterservices_iv.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_iv( sites=["01491000", "01645000"], start="2019-02-14", end="2020-02-15" ) @@ -143,11 +145,11 @@ def test_get_iv(requests_mock): assert df.size == 563380 assert md.url == request_url - assert_metadata(requests_mock, request_url, md, site, None, format) + assert_metadata(httpx_mock, request_url, md, site, None, format) @pytest.mark.parametrize("site_input_type_list", [True, False]) -def test_get_iv_site_value_types(requests_mock, site_input_type_list): +def test_get_iv_site_value_types(httpx_mock, site_input_type_list): """Tests get_iv method for valid input type for the 'sites' parameter""" _format = "json" site = "01491000" @@ -156,7 +158,7 @@ def test_get_iv_site_value_types(requests_mock, site_input_type_list): f"&startDT=2019-02-14&endDT=2020-02-15&sites={site}" ) response_file_path = "tests/data/waterservices_iv.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) if site_input_type_list: sites = [site] else: @@ -168,7 +170,7 @@ def test_get_iv_site_value_types(requests_mock, site_input_type_list): assert md.url == request_url -def test_get_info(requests_mock): +def test_get_info(httpx_mock): """ Verify get_info builds the expected request URL and returns a DataFrame. Note that only sites and format are passed as query params @@ -179,7 +181,7 @@ def test_get_info(requests_mock): parameter_cd = "00618" request_url = f"https://waterservices.usgs.gov/nwis/site?sites={site}¶meterCd={parameter_cd}&siteOutput=Expanded&format={format}" response_file_path = "tests/data/waterservices_site.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_info(sites=["01491000", "01645000"], parameterCd="00618") if not isinstance(df, DataFrame): raise TypeError(f"{type(df)} is not DataFrame base class type") @@ -194,10 +196,10 @@ def test_get_info(requests_mock): assert df.size == size assert md.url == request_url - assert_metadata(requests_mock, request_url, md, site, [parameter_cd], format) + assert_metadata(httpx_mock, request_url, md, site, [parameter_cd], format) -def test_get_discharge_peaks(requests_mock): +def test_get_discharge_peaks(httpx_mock): """Verify get_discharge_peaks builds the expected URL and returns a DataFrame.""" format = "rdb" site = "01594440" @@ -206,17 +208,17 @@ def test_get_discharge_peaks(requests_mock): "&begin_date=2000-02-14&end_date=2020-02-15" ) response_file_path = "tests/data/waterservices_peaks.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_discharge_peaks(sites=[site], start="2000-02-14", end="2020-02-15") if not isinstance(df, DataFrame): raise TypeError(f"{type(df)} is not DataFrame base class type") assert df.size == 240 - assert_metadata(requests_mock, request_url, md, site, None, format) + assert_metadata(httpx_mock, request_url, md, site, None, format) @pytest.mark.parametrize("site_input_type_list", [True, False]) -def test_get_discharge_peaks_sites_value_types(requests_mock, site_input_type_list): +def test_get_discharge_peaks_sites_value_types(httpx_mock, site_input_type_list): """Tests get_discharge_peaks for valid input types of the 'sites' parameter""" _format = "rdb" @@ -226,7 +228,7 @@ def test_get_discharge_peaks_sites_value_types(requests_mock, site_input_type_li "&begin_date=2000-02-14&end_date=2020-02-15" ) response_file_path = "tests/data/waterservices_peaks.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) if site_input_type_list: sites = [site] else: @@ -249,22 +251,22 @@ def test_get_ratings_validation(): ) -def test_get_ratings(requests_mock): +def test_get_ratings(httpx_mock): """Verify get_ratings builds the expected URL and returns a DataFrame.""" format = "rdb" site = "01594440" request_url = f"https://nwis.waterdata.usgs.gov/nwisweb/get_ratings/?site_no={site}&file_type=base" response_file_path = "tests/data/waterservices_ratings.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_ratings(site_no=site) if not isinstance(df, DataFrame): raise TypeError(f"{type(df)} is not DataFrame base class type") assert df.size == 33 - assert_metadata(requests_mock, request_url, md, site, None, format) + assert_metadata(httpx_mock, request_url, md, site, None, format) -def test_what_sites(requests_mock): +def test_what_sites(httpx_mock): """Verify what_sites builds the expected URL and returns a DataFrame.""" size = 2472 format = "rdb" @@ -275,7 +277,7 @@ def test_what_sites(requests_mock): f"¶meterCd={parameter_cd}&hasDataTypeCd=dv&format={format}" ) response_file_path = "tests/data/nwis_sites.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_sites( bBox=[-83.0, 36.5, -81.0, 38.5], @@ -298,25 +300,25 @@ def test_what_sites(requests_mock): size += len(df) assert df.size == size - assert_metadata(requests_mock, request_url, md, None, parameter_cd_list, format) + assert_metadata(httpx_mock, request_url, md, None, parameter_cd_list, format) -def test_get_stats(requests_mock): +def test_get_stats(httpx_mock): """Verify get_stats builds the expected URL and returns a DataFrame.""" format = "rdb" request_url = f"https://waterservices.usgs.gov/nwis/stat?sites=01491000%2C01645000&format={format}" response_file_path = "tests/data/waterservices_stats.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_stats(sites=["01491000", "01645000"]) if not isinstance(df, DataFrame): raise TypeError(f"{type(df)} is not DataFrame base class type") assert df.size == 51936 - assert_metadata(requests_mock, request_url, md, None, None, format) + assert_metadata(httpx_mock, request_url, md, None, None, format) @pytest.mark.parametrize("site_input_type_list", [True, False]) -def test_get_stats_site_value_types(requests_mock, site_input_type_list): +def test_get_stats_site_value_types(httpx_mock, site_input_type_list): """Tests get_stats method for valid input types for the 'sites' parameter""" _format = "rdb" site = "01491000" @@ -324,7 +326,7 @@ def test_get_stats_site_value_types(requests_mock, site_input_type_list): f"https://waterservices.usgs.gov/nwis/stat?sites={site}&format={_format}" ) response_file_path = "tests/data/waterservices_stats.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) if site_input_type_list: sites = [site] else: @@ -335,23 +337,28 @@ def test_get_stats_site_value_types(requests_mock, site_input_type_list): assert df.size == 51936 -def mock_request(requests_mock, request_url, file_path): +def mock_request(httpx_mock, request_url, file_path): with open(file_path) as text: - requests_mock.get( - request_url, text=text.read(), headers={"mock_header": "value"} + httpx_mock.add_response( + method="GET", + url=request_url, + text=text.read(), + headers={"mock_header": "value"}, ) -def assert_metadata(requests_mock, request_url, md, site, parameter_cd, format): +def assert_metadata(httpx_mock, request_url, md, site, parameter_cd, format): assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" if site is not None: site_request_url = ( f"https://waterservices.usgs.gov/nwis/site?sites={site}&format=rdb" ) with open("tests/data/waterservices_site.txt") as text: - requests_mock.get(site_request_url, text=text.read()) + httpx_mock.add_response( + method="GET", url=site_request_url, text=text.read() + ) site_info, _ = md.site_info if not isinstance(site_info, DataFrame): raise AssertionError(f"{type(site_info)} is not DataFrame base class type") diff --git a/tests/wqp_test.py b/tests/wqp_test.py index f432ab26..356f7ac8 100644 --- a/tests/wqp_test.py +++ b/tests/wqp_test.py @@ -17,7 +17,7 @@ ) -def test_get_results(requests_mock): +def test_get_results(httpx_mock): """Tests water quality portal ratings query""" request_url = ( "https://www.waterqualitydata.us/data/Result/Search?siteid=WIDNR_WQX-10032762" @@ -25,7 +25,7 @@ def test_get_results(requests_mock): "&mimeType=csv" ) response_file_path = "tests/data/wqp_results.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_results( siteid="WIDNR_WQX-10032762", characteristicName="Specific conductance", @@ -36,12 +36,12 @@ def test_get_results(requests_mock): assert df.shape == (5, 65) assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None assert df["ActivityStartDateTime"].notna().all() -def test_get_results_WQX3(requests_mock): +def test_get_results_WQX3(httpx_mock): """Tests water quality portal results query with new WQX3.0 profile""" request_url = ( "https://www.waterqualitydata.us/wqx3/Result/search?siteid=WIDNR_WQX-10032762" @@ -50,7 +50,7 @@ def test_get_results_WQX3(requests_mock): "&dataProfile=fullPhysChem" ) response_file_path = "tests/data/wqp3_results.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = get_results( legacy=False, siteid="WIDNR_WQX-10032762", @@ -62,151 +62,154 @@ def test_get_results_WQX3(requests_mock): assert df.shape == (5, 186) assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None assert df["Activity_StartDateTime"].notna().all() -def test_what_sites(requests_mock): +def test_what_sites(httpx_mock): """Tests Water quality portal sites query""" request_url = ( "https://www.waterqualitydata.us/data/Station/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_sites.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_sites(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 239868 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def test_what_organizations(requests_mock): +def test_what_organizations(httpx_mock): """Tests Water quality portal organizations query""" request_url = ( "https://www.waterqualitydata.us/data/Organization/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_organizations.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_organizations(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 576 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def test_what_projects(requests_mock): +def test_what_projects(httpx_mock): """Tests Water quality portal projects query""" request_url = ( "https://www.waterqualitydata.us/data/Project/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_projects.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_projects(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 530 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def test_what_activities(requests_mock): +def test_what_activities(httpx_mock): """Tests Water quality portal activities query""" request_url = ( "https://www.waterqualitydata.us/data/Activity/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_activities.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_activities(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 5087443 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def test_what_detection_limits(requests_mock): +def test_what_detection_limits(httpx_mock): """Tests Water quality portal detection limits query""" request_url = ( "https://www.waterqualitydata.us/data/ResultDetectionQuantitationLimit/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_detection_limits.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_detection_limits(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 98770 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def test_what_habitat_metrics(requests_mock): +def test_what_habitat_metrics(httpx_mock): """Tests Water quality portal habitat metrics query""" request_url = ( "https://www.waterqualitydata.us/data/BiologicalMetric/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_habitat_metrics.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_habitat_metrics(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 48114 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def test_what_project_weights(requests_mock): +def test_what_project_weights(httpx_mock): """Tests Water quality portal project weights query""" request_url = ( "https://www.waterqualitydata.us/data/ProjectMonitoringLocationWeighting/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_project_weights.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_project_weights(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 33098 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def test_what_activity_metrics(requests_mock): +def test_what_activity_metrics(httpx_mock): """Tests Water quality portal activity metrics query""" request_url = ( "https://www.waterqualitydata.us/data/ActivityMetric/Search?statecode=US%3A34&characteristicName=Chloride" "&mimeType=csv" ) response_file_path = "tests/data/wqp_activity_metrics.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, md = what_activity_metrics(statecode="US:34", characteristicName="Chloride") assert type(df) is DataFrame assert df.size == 378 assert md.url == request_url assert isinstance(md.query_time, datetime.timedelta) - assert md.header == {"mock_header": "value"} + assert md.header.get("mock_header") == "value" assert md.comment is None -def mock_request(requests_mock, request_url, file_path): +def mock_request(httpx_mock, request_url, file_path): with open(file_path) as text: - requests_mock.get( - request_url, text=text.read(), headers={"mock_header": "value"} + httpx_mock.add_response( + method="GET", + url=request_url, + text=text.read(), + headers={"mock_header": "value"}, ) @@ -220,7 +223,7 @@ def test_check_kwargs(): kwargs = _check_kwargs(kwargs) -def test_get_results_wqx3_preserves_user_dataProfile(requests_mock): +def test_get_results_wqx3_preserves_user_dataProfile(httpx_mock): """A valid user-supplied WQX3.0 profile must not be overwritten. Regression: previously the `else` branch of the `dataProfile` validation @@ -232,11 +235,11 @@ def test_get_results_wqx3_preserves_user_dataProfile(requests_mock): "siteid=UTAHDWQ_WQX-4993795&mimeType=csv&dataProfile=narrow" ) response_file_path = "tests/data/wqp3_results.txt" - mock_request(requests_mock, request_url, response_file_path) + mock_request(httpx_mock, request_url, response_file_path) df, _md = get_results( legacy=False, siteid="UTAHDWQ_WQX-4993795", dataProfile="narrow" ) assert isinstance(df, DataFrame) - sent = requests_mock.request_history[-1] - assert sent.qs.get("dataprofile") == ["narrow"] + sent = httpx_mock.get_requests()[-1] + assert sent.url.params.get("dataProfile") == "narrow"