From 6f7b54efbc0f46989cbc5c50366e1516cb5ec551 Mon Sep 17 00:00:00 2001 From: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com> Date: Thu, 7 May 2026 15:03:36 -0700 Subject: [PATCH 1/2] refactor: small code quality pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - backend: information_schema page size constant, dedupe HTTP→Ibis errors, parse_qsl import, version from importlib.metadata - http: poll sleep respects deadline, guard missing columns, typed _safe_call - types: trim trailing blank lines - ruff format/import order --- src/ibis_hotdata/backend.py | 66 ++++++++++++++++++++++++------------- src/ibis_hotdata/http.py | 26 +++++++++++---- src/ibis_hotdata/types.py | 2 -- 3 files changed, 63 insertions(+), 31 deletions(-) diff --git a/src/ibis_hotdata/backend.py b/src/ibis_hotdata/backend.py index a2487e6..de0a290 100644 --- a/src/ibis_hotdata/backend.py +++ b/src/ibis_hotdata/backend.py @@ -16,14 +16,12 @@ from __future__ import annotations import contextlib -import urllib.parse from collections.abc import Iterable, Mapping from functools import cached_property +from importlib.metadata import PackageNotFoundError +from importlib.metadata import version as pkg_version from typing import TYPE_CHECKING, Any -from urllib.parse import ParseResult, unquote_plus - -import sqlglot as sg -import sqlglot.expressions as sge +from urllib.parse import ParseResult, parse_qsl, unquote_plus import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com @@ -31,13 +29,27 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch import ibis.expr.types as ir - -from ibis.backends import CanListCatalog, CanListDatabase, HasCurrentCatalog, HasCurrentDatabase, NoExampleLoader +import sqlglot as sg +import sqlglot.expressions as sge +from ibis.backends import ( + CanListCatalog, + CanListDatabase, + HasCurrentCatalog, + HasCurrentDatabase, + NoExampleLoader, +) from ibis.backends.sql import SQLBackend from ibis_hotdata.http import HotdataAPIError, HotdataClient from ibis_hotdata.types import dtype_from_hotdata_sql_type, dtype_from_json_value +_INFORMATION_SCHEMA_PAGE_SIZE = 500 + + +def _ibis_err_from_hotdata(exc: HotdataAPIError) -> com.IbisError: + return com.IbisError(str(exc)) + + if TYPE_CHECKING: from collections.abc import Iterator @@ -98,13 +110,15 @@ def _from_url(self, url: ParseResult, **kwarg_overrides: Any): ``default_schema``, ``prefer_async``. * If ``token`` is omitted, ``urlparse`` password (`user:TOKEN@`) is accepted. """ - q = dict(urllib.parse.parse_qsl(url.query, keep_blank_values=True)) + q = dict(parse_qsl(url.query, keep_blank_values=True)) q.update(kwarg_overrides) netloc = url.netloc path_prefix = url.path.rstrip("/") if not netloc: - raise com.IbisError("hotdata:// URL requires a network location, e.g. hotdata://api.hotdata.dev/") + raise com.IbisError( + "hotdata:// URL requires a network location, e.g. hotdata://api.hotdata.dev/" + ) verify = q.pop("verify_ssl", None) if verify is None: @@ -117,9 +131,7 @@ def _from_url(self, url: ParseResult, **kwarg_overrides: Any): timeout = float(q.pop("timeout", "120")) api_url = q.pop("api_url", None) or ("https://" + netloc + path_prefix) - token = q.pop("token", None) or ( - unquote_plus(url.password) if url.password else None - ) + token = q.pop("token", None) or (unquote_plus(url.password) if url.password else None) workspace_id = q.pop("workspace_id", None) prefer_async_s = q.pop("prefer_async", "false") @@ -236,7 +248,9 @@ def _infer_default_schema(self, connection_id: str) -> str: ) if self._default_schema is not None: if self._default_schema not in schemas: - raise com.IbisInputError(f"Unknown schema {self._default_schema!r} for connection {connection_id!r}") + raise com.IbisInputError( + f"Unknown schema {self._default_schema!r} for connection {connection_id!r}" + ) return self._default_schema if len(schemas) == 1: self._default_schema = schemas[0] @@ -287,7 +301,9 @@ def list_databases( schemas = sorted( { row["schema"] - for row in self._iterate_information_schema({"connection_id": conn}, include_columns=False) + for row in self._iterate_information_schema( + {"connection_id": conn}, include_columns=False + ) } ) return self._filter_with_like(list(schemas), like) @@ -308,7 +324,10 @@ def list_tables( params["schema"] = schema_part tables = sorted( - {row["table"] for row in self._iterate_information_schema(params, include_columns=False)} + { + row["table"] + for row in self._iterate_information_schema(params, include_columns=False) + } ) return self._filter_with_like(tables, like) @@ -318,7 +337,7 @@ def _iterate_information_schema( cursor: str | None = None while True: params: dict[str, Any] = dict(filters) - params["limit"] = 500 + params["limit"] = _INFORMATION_SCHEMA_PAGE_SIZE params["include_columns"] = include_columns if cursor: params["cursor"] = cursor @@ -372,7 +391,7 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: poll_timeout_s=self._poll_timeout_s, ) except HotdataAPIError as exc: - raise com.IbisError(str(exc)) from exc + raise _ibis_err_from_hotdata(exc) from exc cols = data["columns"] nulls = data["nullable"] @@ -404,7 +423,7 @@ def _safe_raw_sql( poll_timeout_s=self._poll_timeout_s, ) except HotdataAPIError as exc: - raise com.IbisError(str(exc)) from exc + raise _ibis_err_from_hotdata(exc) from exc cur = HotdataRowsCursor(payload["rows"]) try: @@ -414,7 +433,6 @@ def _safe_raw_sql( def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: import pandas as pd - from ibis.formats.pandas import PandasData try: @@ -430,7 +448,7 @@ def upload_file(self, data: bytes) -> dict[str, Any]: try: return self._http.upload_file(data) except HotdataAPIError as exc: - raise com.IbisError(str(exc)) from exc + raise _ibis_err_from_hotdata(exc) from exc def create_dataset_from_upload( self, @@ -453,7 +471,7 @@ def create_dataset_from_upload( file_format=file_format, ) except HotdataAPIError as exc: - raise com.IbisError(str(exc)) from exc + raise _ibis_err_from_hotdata(exc) from exc def create_table(self, *_args: Any, **_kwargs: Any) -> ir.Table: raise NotImplementedError( @@ -469,4 +487,8 @@ def _register_in_memory_table(self, _op: ops.InMemoryTable) -> None: @cached_property def version(self) -> str: - return "Hotdata REST API (/v1/query)" + try: + v = pkg_version("ibis-hotdata") + except PackageNotFoundError: + v = "0.0.0" + return f"ibis-hotdata {v} (Hotdata /v1/query)" diff --git a/src/ibis_hotdata/http.py b/src/ibis_hotdata/http.py index beff289..75f05e8 100644 --- a/src/ibis_hotdata/http.py +++ b/src/ibis_hotdata/http.py @@ -3,8 +3,8 @@ from __future__ import annotations import time -from collections.abc import Mapping -from typing import Any, MutableMapping +from collections.abc import Callable, Mapping +from typing import Any, MutableMapping, TypeVar from hotdata import ApiClient, Configuration from hotdata.api import ( @@ -21,6 +21,15 @@ from hotdata.models.async_query_response import AsyncQueryResponse from hotdata.models.query_response import QueryResponse +T = TypeVar("T") + + +def _sleep_until(deadline: float, interval: float) -> None: + """Sleep up to ``interval`` s but never past ``deadline`` (cleaner timeout behavior).""" + remaining = deadline - time.monotonic() + if remaining > 0: + time.sleep(min(interval, remaining)) + class HotdataAPIError(Exception): def __init__(self, message: str, *, status_code: int | None = None, body: Any = None): @@ -53,7 +62,9 @@ def __init__( verify_ssl: bool | str = True, ) -> None: host = api_url.rstrip("/") - conf = Configuration(host=host, api_key=token, workspace_id=workspace_id, session_id=session_id) + conf = Configuration( + host=host, api_key=token, workspace_id=workspace_id, session_id=session_id + ) if verify_ssl is False: conf.verify_ssl = False elif isinstance(verify_ssl, str): @@ -78,7 +89,7 @@ def close(self) -> None: if pool is not None: pool.clear() - def _safe_call(self, fn: Any, /, *args: Any, **kwargs: Any) -> Any: + def _safe_call(self, fn: Callable[..., T], /, *args: Any, **kwargs: Any) -> T: try: return fn(*args, _request_timeout=self._timeout, **kwargs) except ApiException as exc: @@ -130,7 +141,7 @@ def execute_query( return self._poll_result_ready( result_id, deadline=deadline, poll_interval_s=poll_interval_s ) - time.sleep(poll_interval_s) + _sleep_until(deadline, poll_interval_s) raise HotdataAPIError("Timeout waiting for asynchronous query") raise HotdataAPIError("Unexpected query response type") @@ -165,12 +176,13 @@ def _poll_result_ready( raise HotdataAPIError(d.get("error_message") or "Result failed") if st == "ready" or (d.get("rows") is not None and d.get("columns")): return self._normalize_result_payload(d) - time.sleep(poll_interval_s) + _sleep_until(deadline, poll_interval_s) raise HotdataAPIError("Timeout waiting for query result payload") @staticmethod def _normalize_result_payload(data: MutableMapping[str, Any]) -> dict[str, Any]: - columns = list(data["columns"]) + raw = data.get("columns") + columns = list(raw) if raw is not None else [] nullable = list(data.get("nullable") or []) if len(nullable) < len(columns): nullable.extend([True] * (len(columns) - len(nullable))) diff --git a/src/ibis_hotdata/types.py b/src/ibis_hotdata/types.py index a879264..71f76fd 100644 --- a/src/ibis_hotdata/types.py +++ b/src/ibis_hotdata/types.py @@ -39,5 +39,3 @@ def dtype_from_json_value(value: Any) -> dt.DataType | None: return dt.Array(dt.JSON()) return dt.String() - - From cdd6733f779e67ad39de6b5787f2ebd8670f4b2e Mon Sep 17 00:00:00 2001 From: Eddie A Tejeda <669988+eddietejeda@users.noreply.github.com> Date: Fri, 8 May 2026 12:47:57 -0700 Subject: [PATCH 2/2] refactor: make query results Arrow-only Always submit Hotdata queries asynchronously and materialize successful results from the Arrow IPC result endpoint so the backend has one typed execution path. --- README.md | 5 +- examples/_helpers.py | 17 +---- src/ibis_hotdata/backend.py | 76 ++++---------------- src/ibis_hotdata/http.py | 126 ++++++++++++++++++++++++---------- src/ibis_hotdata/types.py | 27 +------- tests/test_hotdata_backend.py | 96 +++++++++++++++++--------- tests/test_hotdata_http.py | 97 ++++++++++++++++---------- tests/test_hotdata_types.py | 26 +------ 8 files changed, 237 insertions(+), 233 deletions(-) diff --git a/README.md b/README.md index 45a9bc7..7c85c93 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,8 @@ con = ibis.hotdata.connect( timeout=120.0, default_connection=None, # Hotdata connection id → Ibis catalog default_schema=None, # remote schema → Ibis database - prefer_async=False, + poll_interval_s=0.25, + poll_timeout_s=600.0, ) ``` @@ -41,7 +42,7 @@ con = ibis.connect( **Mapping:** Ibis **catalog** = Hotdata connection id; **database** = remote schema; **table** = table name. SQL references look like `connection.schema.table`. With a single connection and schema, defaults are inferred; otherwise set `default_connection` / `default_schema` or qualify `con.table(..., database=(conn_id, schema))`. -**Execution:** SQL is compiled with Ibis’s **Postgres** SQLGlot compiler. The client uses `POST /v1/query`; with `prefer_async=True` it follows `202` and polls query-run and result endpoints until rows are ready. Tuning: `poll_interval_s`, `poll_timeout_s` on `connect()`. +**Execution:** SQL is compiled with Ibis’s **Postgres** SQLGlot compiler. The client submits queries asynchronously with `POST /v1/query`, polls `GET /v1/query-runs/{id}`, then downloads ready results as Arrow IPC from `GET /v1/results/{id}`. Tuning: `poll_interval_s`, `poll_timeout_s` on `connect()`. **Types:** Typed tables come from Hotdata’s information schema. `con.sql(...)` types are inferred from a small preview query; see [Hotdata SQL](https://www.hotdata.dev/docs/sql) for server behavior. diff --git a/examples/_helpers.py b/examples/_helpers.py index dffefd2..8b9e887 100644 --- a/examples/_helpers.py +++ b/examples/_helpers.py @@ -175,17 +175,11 @@ def parser(description: str) -> argparse.ArgumentParser: action="store_true", help="Disable TLS verification (dev only)", ) - p.add_argument( - "--prefer-async", - action="store_true", - help="Prefer async POST /v1/query", - ) p.add_argument("--timeout", type=float, default=120.0) p.add_argument( "--default-connection", dest="default_connection", - default=os.environ.get("HOTDATA_DEFAULT_CONNECTION") - or DEFAULT_TPCH_CONNECTION, + default=os.environ.get("HOTDATA_DEFAULT_CONNECTION") or DEFAULT_TPCH_CONNECTION, help=f"Connection id (= Ibis catalog). Env HOTDATA_DEFAULT_CONNECTION. Default {DEFAULT_TPCH_CONNECTION!r}.", ) p.add_argument( @@ -200,11 +194,7 @@ def parser(description: str) -> argparse.ArgumentParser: def parsed_args(parser: argparse.ArgumentParser) -> argparse.Namespace: ns = parser.parse_args() if not ns.token.strip() or not ns.workspace_id.strip(): - parser.error( - "Set HOTDATA_TOKEN and HOTDATA_WORKSPACE_ID, or pass --token and --workspace." - ) - if os.environ.get("HOTDATA_PREFER_ASYNC", "").lower() in ("1", "true", "yes"): - ns.prefer_async = True + parser.error("Set HOTDATA_TOKEN and HOTDATA_WORKSPACE_ID, or pass --token and --workspace.") normalize_tpch_defaults(ns) return ns @@ -217,7 +207,6 @@ def connect_kwargs(ns: argparse.Namespace, **extras) -> dict: "token": ns.token.strip(), "workspace_id": ns.workspace_id.strip(), "timeout": ns.timeout, - "prefer_async": ns.prefer_async, "verify_ssl": False if getattr(ns, "insecure", False) else True, } if ns.session_id: @@ -256,7 +245,5 @@ def hotdata_connect_uri(ns: argparse.Namespace) -> str: qs["default_connection"] = dc if ds: qs["default_schema"] = ds - if ns.prefer_async: - qs["prefer_async"] = "true" q = urllib.parse.urlencode(qs) return f"hotdata://{api_host(ns.api_url)}/?{q}" diff --git a/src/ibis_hotdata/backend.py b/src/ibis_hotdata/backend.py index de0a290..7737f2d 100644 --- a/src/ibis_hotdata/backend.py +++ b/src/ibis_hotdata/backend.py @@ -41,7 +41,7 @@ from ibis.backends.sql import SQLBackend from ibis_hotdata.http import HotdataAPIError, HotdataClient -from ibis_hotdata.types import dtype_from_hotdata_sql_type, dtype_from_json_value +from ibis_hotdata.types import dtype_from_hotdata_sql_type _INFORMATION_SCHEMA_PAGE_SIZE = 500 @@ -56,32 +56,6 @@ def _ibis_err_from_hotdata(exc: HotdataAPIError) -> com.IbisError: import pandas as pd -class HotdataRowsCursor: - """DB-API–like cursor backed by prefetched rows (used by `_fetch_from_cursor`).""" - - def __init__(self, rows: list) -> None: - self._rows = rows - self._idx = 0 - - def fetchmany(self, size: int = 1024) -> list: - start = self._idx - end = min(self._idx + size, len(self._rows)) - self._idx = end - return [tuple(r) for r in self._rows[start:end]] - - def fetchall(self) -> list: - return [tuple(r) for r in self._rows[self._idx :]] - - def close(self) -> None: - self._idx = len(self._rows) - - def __iter__(self) -> Iterator[tuple]: - while self._idx < len(self._rows): - row = self._rows[self._idx] - self._idx += 1 - yield tuple(row) - - class Backend( SQLBackend, CanListCatalog, @@ -107,7 +81,7 @@ def _from_url(self, url: ParseResult, **kwarg_overrides: Any): * Base URL defaults to ``https://{host}`` plus optional leading ``path``. * Query string may include ``token``, ``workspace_id``, ``session_id``, ``timeout``, ``verify_ssl`` (``true`` / ``false``), ``default_connection``, - ``default_schema``, ``prefer_async``. + ``default_schema``, ``poll_interval_s``, ``poll_timeout_s``. * If ``token`` is omitted, ``urlparse`` password (`user:TOKEN@`) is accepted. """ q = dict(parse_qsl(url.query, keep_blank_values=True)) @@ -134,8 +108,6 @@ def _from_url(self, url: ParseResult, **kwarg_overrides: Any): token = q.pop("token", None) or (unquote_plus(url.password) if url.password else None) workspace_id = q.pop("workspace_id", None) - prefer_async_s = q.pop("prefer_async", "false") - kwargs = dict( api_url=api_url, token=token, @@ -145,7 +117,6 @@ def _from_url(self, url: ParseResult, **kwarg_overrides: Any): verify_ssl=verify_ssl, default_connection=q.pop("default_connection", None), default_schema=q.pop("default_schema", None), - prefer_async=str(prefer_async_s).lower() in ("true", "1", "yes"), poll_interval_s=float(q.pop("poll_interval_s", "0.25")), poll_timeout_s=float(q.pop("poll_timeout_s", "600")), ) @@ -169,12 +140,14 @@ def do_connect( verify_ssl: bool | str = True, default_connection: str | None = None, default_schema: str | None = None, - prefer_async: bool = False, poll_interval_s: float = 0.25, poll_timeout_s: float = 600.0, ) -> None: """Create an Ibis client for a Hotdata workspace. + Query execution always uses Hotdata's async path and downloads ready + results as Arrow IPC from ``GET /v1/results/{id}``. + Parameters ---------- api_url @@ -196,8 +169,6 @@ def do_connect( default_schema Optional default **database** (remote schema name). If omitted and only one schema exists for the default connection, it is chosen automatically. - prefer_async - When True, requests ``async: true`` on ``POST /v1/query`` (with polling). poll_interval_s Sleep between ``GET /v1/query-runs/{id}`` polls. poll_timeout_s @@ -206,7 +177,6 @@ def do_connect( self.disconnect() self._default_connection = default_connection self._default_schema = default_schema - self._prefer_async = prefer_async self._poll_interval_s = poll_interval_s self._poll_timeout_s = poll_timeout_s @@ -386,62 +356,40 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: try: data = self._http.execute_query( preview_sql, - prefer_async=self._prefer_async, poll_interval_s=self._poll_interval_s, poll_timeout_s=self._poll_timeout_s, ) except HotdataAPIError as exc: raise _ibis_err_from_hotdata(exc) from exc - cols = data["columns"] - nulls = data["nullable"] - row0 = data["rows"][0] if data.get("rows") else None - mapping: dict[str, dt.DataType] = {} - for i, name in enumerate(cols): - null = bool(nulls[i]) if i < len(nulls) else True - if row0 is not None and i < len(row0): - inferred = dtype_from_json_value(row0[i]) - if inferred is not None: - mapping[name] = inferred.copy(nullable=null) - continue - mapping[name] = dt.String(nullable=null) - return sch.Schema(mapping) + from ibis.formats.pyarrow import PyArrowSchema + + return PyArrowSchema.to_ibis(data["pa_table"].schema) @contextlib.contextmanager def _safe_raw_sql( self, query: str | sge.Expression, - ) -> Iterator[HotdataRowsCursor]: + ) -> Iterator[Any]: if not isinstance(query, str): query = query.sql(dialect=self.compiler.dialect, pretty=True) try: payload = self._http.execute_query( query, - prefer_async=self._prefer_async, poll_interval_s=self._poll_interval_s, poll_timeout_s=self._poll_timeout_s, ) except HotdataAPIError as exc: raise _ibis_err_from_hotdata(exc) from exc - cur = HotdataRowsCursor(payload["rows"]) - try: - yield cur - finally: - cur.close() + yield payload["pa_table"] def _fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame: - import pandas as pd from ibis.formats.pandas import PandasData - try: - df = pd.DataFrame.from_records(iter(cursor), columns=schema.names, coerce_float=True) - except Exception: - cursor.close() - raise - df = PandasData.convert_table(df, schema) - return df + df = cursor.to_pandas() + return PandasData.convert_table(df, schema) def upload_file(self, data: bytes) -> dict[str, Any]: """POST ``/v1/files``; returns the upload record (use ``id`` with :meth:`create_dataset_from_upload`).""" diff --git a/src/ibis_hotdata/http.py b/src/ibis_hotdata/http.py index 75f05e8..b76ac71 100644 --- a/src/ibis_hotdata/http.py +++ b/src/ibis_hotdata/http.py @@ -2,9 +2,15 @@ from __future__ import annotations +import io +import json import time from collections.abc import Callable, Mapping -from typing import Any, MutableMapping, TypeVar +from typing import Any, TypeVar + +import pyarrow as pa +import pyarrow_hotfix # noqa: F401 +import pyarrow.ipc as pa_ipc from hotdata import ApiClient, Configuration from hotdata.api import ( @@ -19,10 +25,12 @@ from hotdata.exceptions import ApiException from hotdata.models import CreateDatasetRequest, DatasetSource, QueryRequest, UploadDatasetSource from hotdata.models.async_query_response import AsyncQueryResponse -from hotdata.models.query_response import QueryResponse T = TypeVar("T") +# Matches Hotdata / runtimedb ``GET /v1/results/{{id}}`` Arrow responses. +APPLICATION_ARROW_STREAM = "application/vnd.apache.arrow.stream" + def _sleep_until(deadline: float, interval: float) -> None: """Sleep up to ``interval`` s but never past ``deadline`` (cleaner timeout behavior).""" @@ -48,6 +56,15 @@ def _from_api_exception(exc: ApiException) -> HotdataAPIError: return HotdataAPIError(msg.strip(), status_code=exc.status, body=exc.body) +def _ipc_stream_bytes_to_table(data: bytes) -> pa.Table: + with pa_ipc.open_stream(io.BytesIO(data)) as reader: + return reader.read_all() + + +def _json_utf8(obj: bytes) -> Any: + return json.loads(obj.decode("utf-8")) + + class HotdataClient: """Thin wrapper around the SDK used by the Ibis backend.""" @@ -117,15 +134,12 @@ def execute_query( self, sql: str, *, - prefer_async: bool = False, async_after_ms: int | None = None, poll_interval_s: float = 0.25, poll_timeout_s: float = 600.0, ) -> dict[str, Any]: - req = QueryRequest(sql=sql, var_async=prefer_async, async_after_ms=async_after_ms) + req = QueryRequest(sql=sql, var_async=True, async_after_ms=async_after_ms) out = self._safe_call(self._query.query, req) - if isinstance(out, QueryResponse): - return self._normalize_result_payload(out.model_dump(by_alias=True)) if isinstance(out, AsyncQueryResponse): query_run_id = out.query_run_id deadline = time.monotonic() + poll_timeout_s @@ -138,8 +152,10 @@ def execute_query( result_id = qr.result_id if result_id is None: raise HotdataAPIError("succeeded query run missing result_id") - return self._poll_result_ready( - result_id, deadline=deadline, poll_interval_s=poll_interval_s + return self._poll_result_arrow( + result_id, + deadline=deadline, + poll_interval_s=poll_interval_s, ) _sleep_until(deadline, poll_interval_s) raise HotdataAPIError("Timeout waiting for asynchronous query") @@ -165,37 +181,77 @@ def create_dataset_from_upload( resp = self._safe_call(self._datasets.create_dataset, req) return resp.model_dump(by_alias=True, mode="json") - def _poll_result_ready( - self, result_id: str, *, deadline: float, poll_interval_s: float + def _poll_result_arrow( + self, + result_id: str, + *, + deadline: float, + poll_interval_s: float, ) -> dict[str, Any]: + """Poll ``GET /v1/results/{{id}}`` with ``Accept: application/vnd.apache.arrow.stream``.""" while time.monotonic() < deadline: - res = self._safe_call(self._results.get_result, result_id) - d = res.model_dump(by_alias=True) - st = d.get("status") - if st == "failed": - raise HotdataAPIError(d.get("error_message") or "Result failed") - if st == "ready" or (d.get("rows") is not None and d.get("columns")): - return self._normalize_result_payload(d) - _sleep_until(deadline, poll_interval_s) - raise HotdataAPIError("Timeout waiting for query result payload") - - @staticmethod - def _normalize_result_payload(data: MutableMapping[str, Any]) -> dict[str, Any]: - raw = data.get("columns") - columns = list(raw) if raw is not None else [] - nullable = list(data.get("nullable") or []) - if len(nullable) < len(columns): - nullable.extend([True] * (len(columns) - len(nullable))) - elif len(nullable) > len(columns): - nullable = nullable[: len(columns)] + try: + raw = self._results.get_result_without_preload_content( + result_id, + _headers={"Accept": APPLICATION_ARROW_STREAM}, + _request_timeout=self._timeout, + ) + except ApiException as exc: + raise _from_api_exception(exc) from exc + body = raw.read() + status = raw.status + ctype = (raw.headers.get("Content-Type") or "").split(";")[0].strip().lower() + + if status == 200 and ctype == APPLICATION_ARROW_STREAM.lower(): + table = _ipc_stream_bytes_to_table(body) + return self._arrow_payload_from_table(table, result_id=result_id) + if status == 202: + _sleep_until(deadline, poll_interval_s) + continue + + if status == 409: + d = _json_utf8(body) if body else {} + raise HotdataAPIError( + d.get("error_message") or "Result failed", + status_code=409, + body=d, + ) + + if status == 404: + d = _json_utf8(body) if body else {} + raise HotdataAPIError( + d.get("detail") or f"Result {result_id!r} not found", + status_code=404, + body=d, + ) + + raise HotdataAPIError( + f"Unexpected GET /v1/results/{result_id} status {status}", + status_code=status, + body=body, + ) + + raise HotdataAPIError("Timeout waiting for Arrow query result") + + def _arrow_payload_from_table( + self, + table: pa.Table, + *, + result_id: str, + ) -> dict[str, Any]: + sch = table.schema + columns = sch.names + nullable = [sch.field(i).nullable for i in range(len(columns))] return { + "format": "arrow", + "pa_table": table, "columns": columns, "nullable": nullable, - "rows": list(data["rows"]) if data.get("rows") is not None else [], - "row_count": data.get("row_count"), - "execution_time_ms": data.get("execution_time_ms"), - "query_run_id": data.get("query_run_id"), - "result_id": data.get("result_id"), - "warning": data.get("warning"), + "rows": [], + "result_id": result_id, + "row_count": table.num_rows, + "execution_time_ms": None, + "query_run_id": None, + "warning": None, } diff --git a/src/ibis_hotdata/types.py b/src/ibis_hotdata/types.py index 71f76fd..c3a4bad 100644 --- a/src/ibis_hotdata/types.py +++ b/src/ibis_hotdata/types.py @@ -1,10 +1,7 @@ -"""Map Hotdata metadata and JSON cells to Ibis dtypes.""" +"""Map Hotdata metadata to Ibis dtypes.""" from __future__ import annotations -from decimal import Decimal -from typing import Any - import ibis.expr.datatypes as dt from ibis.backends.sql.datatypes import PostgresType @@ -17,25 +14,3 @@ def dtype_from_hotdata_sql_type(sql_type: str | None, *, nullable: bool) -> dt.D return PostgresType.from_string(sql_type.strip(), nullable=nullable) except Exception: return dt.String(nullable=nullable) - - -def dtype_from_json_value(value: Any) -> dt.DataType | None: - """Infer an Ibis dtype from a deserialized JSON cell (no nullability signal).""" - if value is None: - return None - if isinstance(value, bool): - return dt.Boolean() - if isinstance(value, int): - return dt.Int64() - if isinstance(value, float): - return dt.Float64() - if isinstance(value, Decimal): - return dt.Decimal(precision=None, scale=None) - if isinstance(value, str): - return dt.String() - if isinstance(value, dict): - return dt.JSON() - if isinstance(value, list): - return dt.Array(dt.JSON()) - - return dt.String() diff --git a/tests/test_hotdata_backend.py b/tests/test_hotdata_backend.py index caef291..771d23b 100644 --- a/tests/test_hotdata_backend.py +++ b/tests/test_hotdata_backend.py @@ -1,9 +1,12 @@ from __future__ import annotations +import io import json import ibis import ibis.common.exceptions as com +import pyarrow as pa +import pyarrow.ipc as ipc import pytest from werkzeug.wrappers import Request, Response @@ -25,37 +28,52 @@ ] +def arrow_stream(table: pa.Table) -> bytes: + sink = io.BytesIO() + with ipc.new_stream(sink, table.schema) as writer: + writer.write_table(table) + return sink.getvalue() + + def test_connect_via_url(httpserver: HTTPServer, srv: str): - url = ( - f"hotdata://127.0.0.1:{httpserver.port}" - "?token=tok&workspace_id=ws_demo&verify_ssl=false" - ) + url = f"hotdata://127.0.0.1:{httpserver.port}?token=tok&workspace_id=ws_demo&verify_ssl=false" con = ibis.connect(url) assert getattr(con, "name", "") == "hotdata" def test_connect_via_url_password_token(httpserver: HTTPServer): token = "secret_pass" - url = ( - f"hotdata://u:{token}@127.0.0.1:{httpserver.port}" - "/?workspace_id=ws_pw&verify_ssl=false" - ) + url = f"hotdata://u:{token}@127.0.0.1:{httpserver.port}/?workspace_id=ws_pw&verify_ssl=false" con = ibis.connect(url) assert getattr(con, "name", "") == "hotdata" def test_sql_execution(httpserver: HTTPServer, srv: str): - body = { - "columns": ["x"], - "nullable": [False], - "rows": [[1]], - "row_count": 1, - "execution_time_ms": 3, - "query_run_id": "qr-sync", - "result_id": None, - "warning": None, - } - httpserver.expect_request("/v1/query", method="POST").respond_with_json(body) + httpserver.expect_request("/v1/query", method="POST").respond_with_json( + { + "query_run_id": "run1", + "status": "queued", + "status_url": "http://poll", + "reason": None, + }, + status=202, + ) + httpserver.expect_request("/v1/query-runs/run1").respond_with_json( + { + "created_at": "2026-01-01T00:00:00Z", + "snapshot_id": "snap", + "sql_hash": "h", + "sql_text": "select 1", + "status": "succeeded", + "result_id": "res1", + "id": "run1", + } + ) + httpserver.expect_request("/v1/results/res1").respond_with_data( + arrow_stream(pa.table({"x": [1]})), + status=200, + content_type="application/vnd.apache.arrow.stream", + ) con = ibis.hotdata.connect( api_url=srv, @@ -252,22 +270,38 @@ def test_ambiguous_default_connection(httpserver: HTTPServer, srv: str): def test_x_session_header_on_query(httpserver: HTTPServer, srv: str): seen: list[str | None] = [] - sync = { - "columns": ["n"], - "nullable": [True], - "rows": [[0]], - "row_count": 1, - "execution_time_ms": 1, - "query_run_id": "qr", - "result_id": None, - "warning": None, - } - def on_post(req: Request) -> Response: seen.append(req.headers.get("X-Session-Id")) - return Response(json.dumps(sync), status=200, content_type="application/json") + return Response( + json.dumps( + { + "query_run_id": "run1", + "status": "queued", + "status_url": "http://poll", + "reason": None, + } + ), + status=202, + content_type="application/json", + ) httpserver.expect_request("/v1/query", method="POST").respond_with_handler(on_post) + httpserver.expect_request("/v1/query-runs/run1").respond_with_json( + { + "created_at": "2026-01-01T00:00:00Z", + "snapshot_id": "snap", + "sql_hash": "h", + "sql_text": "select 0", + "status": "succeeded", + "result_id": "res1", + "id": "run1", + } + ) + httpserver.expect_request("/v1/results/res1").respond_with_data( + arrow_stream(pa.table({"n": [0]})), + status=200, + content_type="application/vnd.apache.arrow.stream", + ) con = ibis.hotdata.connect( api_url=srv, diff --git a/tests/test_hotdata_http.py b/tests/test_hotdata_http.py index 6824391..044d4f2 100644 --- a/tests/test_hotdata_http.py +++ b/tests/test_hotdata_http.py @@ -1,12 +1,15 @@ from __future__ import annotations +import io import json +import pyarrow as pa +import pyarrow.ipc as ipc import pytest from werkzeug.wrappers import Request, Response from pytest_httpserver import HTTPServer -from ibis_hotdata.http import HotdataAPIError, HotdataClient +from ibis_hotdata.http import APPLICATION_ARROW_STREAM, HotdataAPIError, HotdataClient _QR_META = { @@ -36,18 +39,17 @@ def test_execute_query_async_poll(httpserver: HTTPServer): } ) - preview = { - "columns": ["n"], - "nullable": [True], - "rows": [[42]], - "row_count": 1, - "execution_time_ms": 1, - "query_run_id": "qr", - "result_id": "res1", - "warning": None, - "status": "ready", - } - httpserver.expect_oneshot_request("/v1/results/res1").respond_with_json(preview) + table = pa.table({"n": [42]}) + sink = io.BytesIO() + with ipc.new_stream(sink, table.schema) as writer: + writer.write_table(table) + arrow_blob = sink.getvalue() + + httpserver.expect_oneshot_request("/v1/results/res1").respond_with_data( + arrow_blob, + status=200, + content_type=APPLICATION_ARROW_STREAM, + ) client = HotdataClient( api_url=httpserver.url_for("/").rstrip("/"), @@ -57,17 +59,16 @@ def test_execute_query_async_poll(httpserver: HTTPServer): ) body = client.execute_query( "select 41+1", - prefer_async=True, poll_interval_s=0, poll_timeout_s=5, ) client.close() - assert body["columns"] == ["n"] - assert body["rows"] == [[42]] + assert body["format"] == "arrow" + assert body["pa_table"].to_pydict() == {"n": [42]} -def test_sync_error_raises(httpserver: HTTPServer): +def test_query_error_raises(httpserver: HTTPServer): httpserver.expect_oneshot_request("/v1/query", method="POST").respond_with_json( {"detail": "bad"}, status=500 ) @@ -78,31 +79,51 @@ def test_sync_error_raises(httpserver: HTTPServer): verify_ssl=False, ) with pytest.raises(HotdataAPIError): - client.execute_query("select 1", prefer_async=False) + client.execute_query("select 1") client.close() -def test_sync_200_pad_shorter_nullable_array(httpserver: HTTPServer): - body = { - "columns": ["a", "b", "c"], - "nullable": [False], - "rows": [[1, 2, 3]], - "row_count": 1, - "execution_time_ms": 0, - "query_run_id": "qr", - "result_id": None, - "warning": None, - } - httpserver.expect_oneshot_request("/v1/query", method="POST").respond_with_json(body) +def test_result_arrow_poll_handles_accepted_result(httpserver: HTTPServer): + httpserver.expect_oneshot_request("/v1/query", method="POST").respond_with_json( + { + "query_run_id": "run1", + "status": "queued", + "status_url": "http://poll", + "reason": None, + }, + status=202, + ) + httpserver.expect_oneshot_request("/v1/query-runs/run1").respond_with_json( + { + **_QR_META, + "status": "succeeded", + "result_id": "res1", + "id": "run1", + } + ) + httpserver.expect_oneshot_request("/v1/results/res1").respond_with_json( + {"result_id": "res1", "status": "processing"}, + status=202, + ) + + table = pa.table({"n": [42]}) + sink = io.BytesIO() + with ipc.new_stream(sink, table.schema) as writer: + writer.write_table(table) + httpserver.expect_oneshot_request("/v1/results/res1").respond_with_data( + sink.getvalue(), + status=200, + content_type=APPLICATION_ARROW_STREAM, + ) + client = HotdataClient( api_url=httpserver.url_for("/").rstrip("/"), token="t", workspace_id="w", verify_ssl=False, ) - out = client.execute_query("select 1", prefer_async=False) - assert len(out["nullable"]) == 3 - assert out["nullable"][0] is False + out = client.execute_query("select 1", poll_interval_s=0, poll_timeout_s=5) + assert out["pa_table"].to_pydict() == {"n": [42]} def test_async_query_run_failure(httpserver: HTTPServer): @@ -125,7 +146,11 @@ def test_async_query_run_failure(httpserver: HTTPServer): verify_ssl=False, ) with pytest.raises(HotdataAPIError, match="boom"): - client.execute_query("select junk", prefer_async=True, poll_interval_s=0, poll_timeout_s=2) + client.execute_query( + "select junk", + poll_interval_s=0, + poll_timeout_s=2, + ) client.close() @@ -172,7 +197,9 @@ def on_dataset(req: Request) -> Response: } return Response(json.dumps(payload), status=201, content_type="application/json") - httpserver.expect_oneshot_request("/v1/datasets", method="POST").respond_with_handler(on_dataset) + httpserver.expect_oneshot_request("/v1/datasets", method="POST").respond_with_handler( + on_dataset + ) client = HotdataClient( api_url=httpserver.url_for("/").rstrip("/"), diff --git a/tests/test_hotdata_types.py b/tests/test_hotdata_types.py index c4f3f7b..df16d7d 100644 --- a/tests/test_hotdata_types.py +++ b/tests/test_hotdata_types.py @@ -1,12 +1,10 @@ from __future__ import annotations -from decimal import Decimal - import pytest import ibis.expr.datatypes as dt -from ibis_hotdata.types import dtype_from_hotdata_sql_type, dtype_from_json_value +from ibis_hotdata.types import dtype_from_hotdata_sql_type @pytest.mark.parametrize( @@ -38,25 +36,3 @@ def test_dtype_from_hotdata_vendor_name_maps_or_string_fallback(): def test_dtype_from_hotdata_malformed_fallback_string(): out = dtype_from_hotdata_sql_type('"', nullable=False) assert isinstance(out, dt.String) - -@pytest.mark.parametrize( - ("value", "expected_cls"), - [ - (True, dt.Boolean), - (42, dt.Int64), - (3.14, dt.Float64), - (Decimal("1.23"), dt.Decimal), - ("hi", dt.String), - ], -) -def test_dtype_from_json_primitive(value, expected_cls): - out = dtype_from_json_value(value) - assert isinstance(out, expected_cls) - - -def test_dtype_from_json_null_and_container(): - assert dtype_from_json_value(None) is None - coll = dtype_from_json_value([1]) - assert isinstance(coll, dt.Array) - blob = dtype_from_json_value({"a": 1}) - assert isinstance(blob, dt.JSON)