diff --git a/README.md b/README.md index 31a109d..ba6c02a 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,16 @@ Marimo UI helpers for [Hotdata](https://hotdata.dev): run SQL from a notebook, browse catalog metadata, and render results as tables. +## Features + +- **Workspace-aware setup** — build a `HotdataClient` from environment variables, or use `workspace_selector_from_env()` to choose a workspace interactively when no workspace is pinned. +- **Connection health** — show a compact status callout with API, workspace, and optional sandbox context. +- **Catalog browsing** — browse Hotdata connections, schemas, tables, and columns from Marimo UI controls. +- **SQL editor widget** — run SQL against Hotdata, cache the latest successful result, and render results in downstream reactive cells. +- **Native `mo.sql` engine** — register `HotdataMarimoEngine` so Marimo SQL cells can execute through a live `HotdataClient` with `engine=client`. +- **Result display helpers** — render query results, recent results, and run history as notebook-friendly UI. +- **Marimo UI aliases** — importing `hotdata_marimo` attaches helpers such as `mo.ui.hotdata_sql_editor` and `mo.ui.hotdata_table_browser` for discoverability. + ## Install ```bash @@ -39,13 +49,45 @@ Importing `hotdata_marimo` registers discoverability aliases on Marimo’s UI na Use `hm.connection_status(client)` (or `mo.ui.hotdata_connection_status(client)`) for a small API/workspace health callout. +## Marimo SQL Cells + +Register the Hotdata SQL engine once during setup, then pass a `HotdataClient` to Marimo SQL cells: + +```python +import hotdata_marimo as hm + +hm.register_hotdata_sql_engine() +client = hm.from_env() +``` + +```python +_df = mo.sql( + """ + SELECT 1 AS example_value + """, + engine=client, +) +``` + +The engine also exposes Hotdata catalog metadata to Marimo's data-source UI. Hotdata connections are labeled **Hotdata** in the SQL connection picker. + ## Two-cell pattern Keep the editor in one cell and consume `editor.result` in another. The editor caches the last successful run so downstream cells do not re-query the API on every refresh; click **Run on Hotdata** again after you change SQL. While a query is running, a Marimo status spinner is shown. Marimo only shows **what you `return` from a cell**. Calling `mo.vstack(...)` or `hm.query_result(...)` without returning it produces no visible output. -See `examples/hotdata_basic.py` for a full notebook: five Python cells (`mo.vstack` for **controls only**, then a separate cell `return hm.query_result(editor.result)` so results show immediately — **avoid** `mo.lazy` here: it only renders after the block scrolls into view, which looks like an empty cell). If Marimo shows **empty cells**, quit and remove `examples/__marimo__/` so the UI reloads from the `.py` file only. +See `examples/demo.py` for a full runnable notebook flow. + +## Examples + +- `examples/demo.py` — end-to-end browser + editor + result rendering flow. + +Run: + +```bash +uv run marimo edit examples/demo.py --no-token +``` ## Layout @@ -58,7 +100,7 @@ This package depends on [**hotdata-runtime**](https://github.com/hotdata-dev/hot ```bash uv sync --locked uv run pytest -marimo edit examples/hotdata_basic.py --no-token +marimo edit examples/demo.py --no-token ``` To pin **hotdata-runtime** from Git instead of the sibling path, remove the `[tool.uv.sources]` block, set the dependency line as needed, and run `uv lock` again. diff --git a/examples/demo.py b/examples/demo.py new file mode 100644 index 0000000..b89af53 --- /dev/null +++ b/examples/demo.py @@ -0,0 +1,95 @@ +import marimo + +__generated_with = "0.23.5" +app = marimo.App() + + +@app.cell +def _(): + import os + + import marimo as mo + + import hotdata_marimo as hm + + hm.register_hotdata_sql_engine() + return hm, mo, os + + +@app.cell +def _(hm, mo, os): + mo.stop( + not os.environ.get("HOTDATA_API_KEY"), + mo.callout( + mo.md( + "Add **HOTDATA_API_KEY** to your environment " + "to run this example." + ), + kind="warn", + ), + ) + workspace = hm.workspace_selector_from_env() + return (workspace,) + + +@app.cell +def _(hm, workspace): + client = workspace.client + status = hm.connection_status(client) + browser = hm.table_browser(client) + editor = hm.sql_editor( + client, + default_sql="SELECT 1 AS ok", + ) + recent = hm.recent_results(client, limit=20) + history = hm.run_history(client, limit=10) + return browser, client, editor, history, recent, status, workspace + + +@app.cell +def _(browser, editor, mo, recent, status, workspace): + return mo.vstack( + [ + workspace.ui, + status, + browser.ui, + editor.ui, + recent.ui, + ], + gap=2, + ) + + +@app.cell +def _(history): + return history + + +@app.cell +def _(editor, hm): + # Explicitly touch nested widget values so Marimo reruns this cell on clicks. + _run = editor.run.value + _rerun = editor.rerun.value + _clear = editor.clear.value + return hm.query_result(editor.result), _clear, _rerun, _run + + +@app.cell +def _(hm, recent): + _selected = recent.pick.value + return hm.query_result(recent.result, label="Recent result"), _selected + + +@app.cell +def _(client, mo): + _df = mo.sql( + """ + SELECT 1 AS example_value + """, + engine=client, + ) + return + + +if __name__ == "__main__": + app.run() diff --git a/examples/hotdata_basic.py b/examples/hotdata_basic.py deleted file mode 100644 index 18dcd67..0000000 --- a/examples/hotdata_basic.py +++ /dev/null @@ -1,76 +0,0 @@ -import marimo - -__generated_with = "0.23.5" -app = marimo.App() - - -@app.cell -def _(): - import os - - import marimo as mo - - import hotdata_marimo as hm - - return hm, mo, os - - -@app.cell -def _(hm, mo, os): - mo.stop( - not ( - os.environ.get("HOTDATA_API_KEY") - or os.environ.get("HOTDATA_TOKEN") - ), - mo.callout( - mo.md( - "Add **HOTDATA_API_KEY** (or **HOTDATA_TOKEN**) to your environment " - "to run this example." - ), - kind="warn", - ), - ) - client = hm.from_env() - return (client,) - - -@app.cell -def _(client, hm, mo): - id_map = client.connection_id_by_name() - tpch_id = id_map.get("tpch") - mo.stop( - not tpch_id, - mo.callout( - mo.md( - "This example expects a connection named **tpch**. " - "Create it in Hotdata or adjust the name in the notebook." - ), - kind="warn", - ), - ) - browser = hm.table_browser(client, connection_id=tpch_id) - editor = hm.sql_editor( - client, - default_sql="SELECT * FROM tpch.tpch_sf1.nation LIMIT 5", - ) - return browser, editor - - -@app.cell -def _(browser, editor, mo): - mo.vstack([browser.ui, editor.ui], gap=2) - return - - -@app.cell -def _(editor, hm): - # Explicitly touch nested widget values so Marimo reruns this cell on clicks. - _run = editor.run.value - _rerun = editor.rerun.value - _clear = editor.clear.value - hm.query_result(editor.result) - return _clear, _rerun, _run - - -if __name__ == "__main__": - app.run() diff --git a/hotdata_marimo/__init__.py b/hotdata_marimo/__init__.py index 4e8aec4..cc714d7 100644 --- a/hotdata_marimo/__init__.py +++ b/hotdata_marimo/__init__.py @@ -16,6 +16,11 @@ recent_results, run_history, ) +from hotdata_marimo.sql_engine import ( + HotdataMarimoEngine, + register_hotdata_sql_engine, + unregister_hotdata_sql_engine, +) from hotdata_marimo.sql_editor import SqlEditor, sql_editor from hotdata_marimo.table_browser import TableBrowser, connection_picker, table_browser from hotdata_marimo.workspace_selector import WorkspaceSelector, workspace_selector_from_env @@ -23,6 +28,7 @@ __all__ = [ "__version__", "HotdataClient", + "HotdataMarimoEngine", "QueryResult", "RecentResults", "SqlEditor", @@ -39,11 +45,13 @@ "hotdata_workspace_selector", "query_result", "recent_results", + "register_hotdata_sql_engine", + "register_mo_ui_hotdata_aliases", "run_history", "sql_editor", "table_browser", + "unregister_hotdata_sql_engine", "workspace_selector_from_env", - "register_mo_ui_hotdata_aliases", ] hotdata_sql_editor = sql_editor diff --git a/hotdata_marimo/display.py b/hotdata_marimo/display.py index e96d01f..365c03f 100644 --- a/hotdata_marimo/display.py +++ b/hotdata_marimo/display.py @@ -4,9 +4,20 @@ import marimo as mo -from hotdata_runtime.client import HotdataClient -from hotdata_runtime.health import workspace_health_lines -from hotdata_runtime.result import QueryResult +from hotdata_runtime import HotdataClient, QueryResult, workspace_health_lines + + +def _option_map_with_unique_labels( + pairs: list[tuple[str, str]], +) -> dict[str, str]: + counts: dict[str, int] = {} + options: dict[str, str] = {} + for label, value in pairs: + count = counts.get(label, 0) + counts[label] = count + 1 + key = label if count == 0 else f"{label} ({count + 1})" + options[key] = value + return options def query_result( @@ -27,17 +38,18 @@ def query_result( ) else: trunc = None + meta = result.metadata_dict() meta_bits = [] - if result.result_id: - meta_bits.append(f"**result_id** `{result.result_id}`") - if result.query_run_id: - meta_bits.append(f"**query_run_id** `{result.query_run_id}`") - if result.execution_time_ms is not None: - meta_bits.append(f"**execution_time_ms** {result.execution_time_ms}") - if result.warning: - meta_bits.append(f"**warning** {result.warning}") - if result.error_message: - meta_bits.append(f"**error** {result.error_message}") + if meta["result_id"]: + meta_bits.append(f"**result_id** `{meta['result_id']}`") + if meta["query_run_id"]: + meta_bits.append(f"**query_run_id** `{meta['query_run_id']}`") + if meta["execution_time_ms"] is not None: + meta_bits.append(f"**execution_time_ms** {meta['execution_time_ms']}") + if meta["warning"]: + meta_bits.append(f"**warning** {meta['warning']}") + if meta["error_message"]: + meta_bits.append(f"**error** {meta['error_message']}") header = mo.md(" · ".join(meta_bits) if meta_bits else "_No metadata._") df = result.to_pandas() tbl = mo.ui.table( @@ -59,11 +71,12 @@ def query_result( class RecentResults: def __init__(self, client: HotdataClient, *, limit: int = 50) -> None: self._client = client - listing = client.results().list_results(limit=limit, offset=0) - self._results = listing.results - options = { - f"{r.created_at} · {r.status} · {r.id}": r.id for r in self._results - } + self._results = client.list_recent_results(limit=limit, offset=0) + option_pairs = [ + (f"{r.created_at} · {r.status} · {r.result_id}", r.result_id) + for r in self._results + ] + options = _option_map_with_unique_labels(option_pairs) self.pick = mo.ui.dropdown( options=options or {"(no results)": ""}, label="Recent results", @@ -97,7 +110,7 @@ def run_history( limit: int = 20, label: str = "Run history", ): - runs = client.query_runs().list_query_runs(limit=limit).query_runs + runs = client.list_run_history(limit=limit) if not runs: return mo.md("_No query runs returned._") @@ -105,12 +118,11 @@ def run_history( for r in runs: rows.append( { - "created_at": getattr(r, "created_at", None), - "status": getattr(r, "status", None), - "execution_time_ms": getattr(r, "execution_time_ms", None), - "result_id": getattr(r, "result_id", None), - "query_run_id": getattr(r, "id", None) - or getattr(r, "query_run_id", None), + "created_at": r.created_at, + "status": r.status, + "execution_time_ms": r.execution_time_ms, + "result_id": r.result_id, + "query_run_id": r.query_run_id, } ) diff --git a/hotdata_marimo/sql_editor.py b/hotdata_marimo/sql_editor.py index 2fa65a2..590ad86 100644 --- a/hotdata_marimo/sql_editor.py +++ b/hotdata_marimo/sql_editor.py @@ -2,8 +2,7 @@ import marimo as mo -from hotdata_runtime.client import HotdataClient -from hotdata_runtime.result import QueryResult +from hotdata_runtime import HotdataClient, QueryResult class SqlEditor: diff --git a/hotdata_marimo/sql_engine.py b/hotdata_marimo/sql_engine.py new file mode 100644 index 0000000..44985cf --- /dev/null +++ b/hotdata_marimo/sql_engine.py @@ -0,0 +1,375 @@ +"""Marimo ``mo.sql`` engine integration for :class:`~hotdata_runtime.HotdataClient`.""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Any, Literal + +from hotdata_runtime import HotdataClient + +from marimo import _loggers +from marimo._data.models import ( + Database, + DataSourceConnection, + DataTable, + DataTableColumn, + Schema, +) +from marimo._sql.engines.types import InferenceConfig, SQLConnection +from marimo._sql.utils import convert_to_output, sql_type_to_data_type +from marimo._types.ids import VariableName + +LOGGER = _loggers.marimo_logger() + + +def _table_schema_name(t: Any) -> str: + return str(t.var_schema) + + +class HotdataMarimoEngine(SQLConnection[HotdataClient]): + """Marimo :class:`~marimo._sql.engines.types.SQLConnection` backed by Hotdata. + + Catalog methods support Marimo's Data Sources panel. ``execute()`` only runs SQL + via :meth:`~hotdata_runtime.HotdataClient.execute_sql` (no catalog calls in that path). + """ + + def __init__( + self, + connection: HotdataClient, + engine_name: VariableName | None = None, + ) -> None: + super().__init__(connection, engine_name) + self._connections_cache: list[Any] | None = None + self._connection_id_cache: dict[str, str] | None = None + + @property + def source(self) -> str: + return "hotdata" + + @property + def dialect(self) -> str: + # Marimo labels engines as ``{dialect} ({variable_name})``; display_name is patched to "Hotdata". + return "hotdata" + + @staticmethod + def is_compatible(var: Any) -> bool: + return isinstance(var, HotdataClient) + + @property + def inference_config(self) -> InferenceConfig: + return InferenceConfig( + auto_discover_schemas=True, + auto_discover_tables="auto", + auto_discover_columns="auto", + ) + + def _resolve_should_auto_discover( + self, + value: bool | Literal["auto"], + ) -> bool: + if value == "auto": + return True + return value + + def _connection_ids(self) -> dict[str, str]: + if self._connection_id_cache is None: + self._connection_id_cache = { + str(c.name): str(c.id) for c in self._connections() + } + return self._connection_id_cache + + def _connection_id(self, connection_name: str) -> str | None: + return self._connection_ids().get(connection_name) + + def _connections(self) -> list[Any]: + if self._connections_cache is None: + self._connections_cache = list( + self._connection.connections().list_connections().connections + ) + return self._connections_cache + + def _iter_grouped( + self, + *, + connection_id: str | None, + include_columns: bool, + ) -> dict[str, dict[str, list[Any]]]: + grouped: dict[str, dict[str, list[Any]]] = defaultdict( + lambda: defaultdict(list) + ) + for t in self._connection.iter_tables( + connection_id=connection_id, + include_columns=include_columns, + ): + grouped[str(t.connection)][_table_schema_name(t)].append(t) + return grouped + + def get_default_database(self) -> str | None: + listing = self._connections() + if not listing: + return None + return str(listing[0].name) + + def get_default_schema(self) -> str | None: + return None + + def get_databases( + self, + *, + include_schemas: bool | Literal["auto"], + include_tables: bool | Literal["auto"], + include_table_details: bool | Literal["auto"], + ) -> list[Database]: + databases: list[Database] = [] + for c in self._connections(): + name = str(c.name) + if self._resolve_should_auto_discover(include_schemas): + schemas = self.get_schemas( + database=name, + include_tables=self._resolve_should_auto_discover( + include_tables + ), + include_table_details=self._resolve_should_auto_discover( + include_table_details + ), + ) + else: + schemas = [] + databases.append( + Database( + name=name, + dialect=self.dialect, + schemas=schemas, + engine=self._engine_name, + ) + ) + return databases + + def get_schemas( + self, + *, + database: str | None, + include_tables: bool, + include_table_details: bool, + ) -> list[Schema]: + if not database: + return [] + conn_id = self._connection_id(database) + if conn_id is None: + LOGGER.warning("Unknown Hotdata connection name %r", database) + return [] + grouped = self._iter_grouped( + connection_id=conn_id, + include_columns=include_table_details, + ) + inner = grouped.get(database, {}) + schemas: list[Schema] = [] + for schema_name in sorted(inner.keys()): + tables: list[DataTable] = [] + if include_tables: + tables = self.get_tables_in_schema( + schema=schema_name, + database=database, + include_table_details=include_table_details, + ) + if not tables: + continue + schemas.append(Schema(name=schema_name, tables=tables)) + return schemas + + def _data_table_from_table_info(self, t: Any) -> DataTable: + cols: list[DataTableColumn] = [] + for col in t.columns or []: + cols.append( + DataTableColumn( + name=str(col.name), + type=sql_type_to_data_type(str(col.data_type)), + external_type=str(col.data_type), + sample_values=[], + ) + ) + return DataTable( + source_type="connection", + source=self.source, + name=str(t.table), + num_rows=None, + num_columns=len(cols) if cols else None, + variable_name=None, + engine=self._engine_name, + type="table", + columns=cols, + primary_keys=None, + indexes=None, + ) + + def get_tables_in_schema( + self, + *, + schema: str, + database: str, + include_table_details: bool, + ) -> list[DataTable]: + conn_id = self._connection_id(database) + if conn_id is None: + return [] + grouped = self._iter_grouped( + connection_id=conn_id, + include_columns=include_table_details, + ) + tables_info = grouped.get(database, {}).get(schema, []) + out: list[DataTable] = [] + for t in sorted(tables_info, key=lambda x: str(x.table)): + if include_table_details: + if t.columns: + out.append(self._data_table_from_table_info(t)) + continue + dt = self.get_table_details( + table_name=str(t.table), + schema_name=schema, + database_name=database, + ) + if dt is not None: + out.append(dt) + else: + out.append( + DataTable( + source_type="connection", + source=self.source, + name=str(t.table), + num_rows=None, + num_columns=len(t.columns or []) if t.columns else None, + variable_name=None, + engine=self._engine_name, + type="table", + columns=[], + primary_keys=None, + indexes=None, + ) + ) + return out + + def get_table_details( + self, + *, + table_name: str, + schema_name: str, + database_name: str, + ) -> DataTable | None: + conn_id = self._connection_id(database_name) + if conn_id is None: + return None + qualified = f"{database_name}.{schema_name}.{table_name}" + try: + cols_raw = self._connection.columns_for_qualified( + qualified, connection_id=conn_id + ) + except Exception: + LOGGER.warning( + "Failed to load columns for %s", + qualified, + exc_info=True, + ) + return None + cols: list[DataTableColumn] = [] + for col in cols_raw: + cols.append( + DataTableColumn( + name=str(col.name), + type=sql_type_to_data_type(str(col.data_type)), + external_type=str(col.data_type), + sample_values=[], + ) + ) + return DataTable( + source_type="connection", + source=self.source, + name=table_name, + num_rows=None, + num_columns=len(cols), + variable_name=None, + engine=self._engine_name, + type="table", + columns=cols, + primary_keys=None, + indexes=None, + ) + + def execute(self, query: str) -> Any: + qr = self._connection.execute_sql(query) + fmt = self.sql_output_format() + + def to_polars() -> Any: + import polars as pl + + if not qr.columns: + return pl.DataFrame() + return pl.DataFrame(qr.rows, schema=qr.columns, orient="row") + + return convert_to_output( + sql_output_format=fmt, + to_polars=to_polars, + to_pandas=qr.to_pandas, + to_native=to_polars, + ) + + +_HOTDATA_ENGINE_DISPLAY_NAME = "Hotdata" +_ORIGINAL_ENGINE_TO_CONNECTION = None + + +def _install_hotdata_engine_display_name() -> None: + """Show ``Hotdata`` in Marimo's SQL engine / Data Sources UI (not ``sql (client)``).""" + global _ORIGINAL_ENGINE_TO_CONNECTION + if _ORIGINAL_ENGINE_TO_CONNECTION is not None: + return + + import marimo._sql.get_engines as ge + + _ORIGINAL_ENGINE_TO_CONNECTION = ge.engine_to_data_source_connection + + def engine_to_data_source_connection( + variable_name: VariableName, engine: object + ) -> DataSourceConnection: + conn = _ORIGINAL_ENGINE_TO_CONNECTION(variable_name, engine) # type: ignore[arg-type] + if not isinstance(engine, HotdataMarimoEngine): + return conn + return DataSourceConnection( + source=conn.source, + dialect=conn.dialect, + name=conn.name, + display_name=_HOTDATA_ENGINE_DISPLAY_NAME, + databases=conn.databases, + default_database=conn.default_database, + default_schema=conn.default_schema, + ) + + _set_engine_to_data_source_connection(engine_to_data_source_connection) + + +def _set_engine_to_data_source_connection(fn: object) -> None: + """Marimo imports this helper in multiple modules; patch all bindings.""" + import marimo._runtime.runner.hooks_post_execution as hpe + import marimo._runtime.runtime as rt + import marimo._sql.get_engines as ge + + ge.engine_to_data_source_connection = fn # type: ignore[assignment] + hpe.engine_to_data_source_connection = fn # type: ignore[assignment] + rt.engine_to_data_source_connection = fn # type: ignore[assignment] + + +def register_hotdata_sql_engine() -> None: + """Register :class:`HotdataMarimoEngine` with Marimo's SQL engine registry (idempotent).""" + _install_hotdata_engine_display_name() + from marimo._sql.get_engines import SUPPORTED_ENGINES + + if HotdataMarimoEngine in SUPPORTED_ENGINES: + return + SUPPORTED_ENGINES.insert(0, HotdataMarimoEngine) + + +def unregister_hotdata_sql_engine() -> None: + """Remove :class:`HotdataMarimoEngine` from Marimo's registry (mostly for tests).""" + from marimo._sql.get_engines import SUPPORTED_ENGINES + + while HotdataMarimoEngine in SUPPORTED_ENGINES: + SUPPORTED_ENGINES.remove(HotdataMarimoEngine) diff --git a/hotdata_marimo/table_browser.py b/hotdata_marimo/table_browser.py index e76c545..782be26 100644 --- a/hotdata_marimo/table_browser.py +++ b/hotdata_marimo/table_browser.py @@ -4,7 +4,19 @@ import marimo as mo -from hotdata_runtime.client import HotdataClient +from hotdata_runtime import HotdataClient + + +def _connection_options(conns: list[Any]) -> dict[str, str]: + counts: dict[str, int] = {} + options: dict[str, str] = {} + for c in conns: + label = c.name + count = counts.get(label, 0) + counts[label] = count + 1 + key = label if count == 0 else f"{label} ({c.id})" + options[key] = c.id + return options def connection_picker( @@ -21,7 +33,7 @@ def connection_picker( label=label, full_width=full_width, ) - options = {c.name: c.id for c in conns} + options = _connection_options(conns) return mo.ui.dropdown( options=options, label=label, @@ -182,7 +194,10 @@ def ui(self): stack.append(self.table_pick) return mo.vstack(stack, gap=1) - cols = self._client.columns_for_qualified(sel) + cols = self._client.columns_for_qualified( + sel, + connection_id=self.selected_connection_id, + ) if not cols: body = mo.md("_No column metadata returned (check catalog sync)._") else: diff --git a/hotdata_marimo/workspace_selector.py b/hotdata_marimo/workspace_selector.py index 6bd374b..9bdd63d 100644 --- a/hotdata_marimo/workspace_selector.py +++ b/hotdata_marimo/workspace_selector.py @@ -1,13 +1,12 @@ from __future__ import annotations import marimo as mo -from hotdata_runtime.client import HotdataClient -from hotdata_runtime.env import ( +from hotdata_runtime import ( + HotdataClient, default_api_key, default_host, default_session_id, - explicit_workspace_id, - list_workspaces, + resolve_workspace_selection, ) @@ -25,25 +24,19 @@ def __init__( self._api_key = api_key self._host = host or default_host() self._session_id = session_id - self._explicit = explicit_workspace_id() - - workspaces = list_workspaces(api_key, self._host, session_id) - if not workspaces: - raise RuntimeError("No Hotdata workspaces found for this API key.") - + selection = resolve_workspace_selection(api_key, self._host, session_id) + self._explicit = selection.source == "explicit_env" if self._explicit: self._pick = None - self._workspace_id = self._explicit + self._workspace_id = selection.workspace_id return + workspaces = selection.workspaces if len(workspaces) == 1: self._pick = None self._workspace_id = workspaces[0].public_id return - active = [w for w in workspaces if w.active] - chosen = active[0] if active else workspaces[0] - labels: list[tuple[str, str]] = [] seen: set[str] = set() for w in workspaces: @@ -52,10 +45,10 @@ def __init__( seen.add(base) labels.append((label_text, w.public_id)) - labels.sort(key=lambda t: 0 if t[1] == chosen.public_id else 1) + labels.sort(key=lambda t: 0 if t[1] == selection.workspace_id else 1) options = {k: v for k, v in labels} self._pick = mo.ui.dropdown(options=options, label=label, full_width=True) - self._workspace_id = chosen.public_id + self._workspace_id = selection.workspace_id @property def workspace_id(self) -> str: @@ -84,7 +77,7 @@ def ui(self): def workspace_selector_from_env(*, label: str = "Workspace") -> WorkspaceSelector: api_key = default_api_key() if not api_key: - raise RuntimeError("HOTDATA_API_KEY or HOTDATA_TOKEN must be set.") + raise RuntimeError("HOTDATA_API_KEY must be set.") host = default_host() session = default_session_id() return WorkspaceSelector( diff --git a/tests/test_architecture_guardrails.py b/tests/test_architecture_guardrails.py new file mode 100644 index 0000000..8b5ee3b --- /dev/null +++ b/tests/test_architecture_guardrails.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import re +from pathlib import Path + + +REPO_ROOT = Path(__file__).resolve().parents[1] +SOURCE_ROOT = REPO_ROOT / "hotdata_marimo" + + +def test_source_uses_hotdata_runtime_root_imports() -> None: + violations: list[str] = [] + pattern = re.compile( + r"(?m)^\s*(?:from\s+hotdata_runtime\.(client|env|result|health)\s+import" + r"|import\s+hotdata_runtime\.(client|env|result|health)(?:\s|$|,|as))" + ) + + for path in SOURCE_ROOT.rglob("*.py"): + text = path.read_text(encoding="utf-8") + if pattern.search(text): + violations.append(str(path.relative_to(REPO_ROOT))) + + assert not violations, ( + "Use `from hotdata_runtime import ...` in package source; " + f"found submodule imports in: {', '.join(violations)}" + ) diff --git a/tests/test_imports.py b/tests/test_imports.py index 5d8a348..041a3b8 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -3,3 +3,5 @@ def test_package_imports(): assert hm.HotdataClient is not None assert hm.SqlEditor is not None + assert hm.register_hotdata_sql_engine is not None + assert hm.HotdataMarimoEngine is not None diff --git a/tests/test_options.py b/tests/test_options.py new file mode 100644 index 0000000..38d975d --- /dev/null +++ b/tests/test_options.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from hotdata_marimo.display import _option_map_with_unique_labels +from hotdata_marimo.table_browser import _connection_options + + +def test_option_map_with_unique_labels_keeps_all_values(): + options = _option_map_with_unique_labels( + [("dup", "a"), ("dup", "b"), ("dup", "c")] + ) + assert options == { + "dup": "a", + "dup (2)": "b", + "dup (3)": "c", + } + + +def test_connection_options_disambiguates_duplicate_names(): + conns = [ + SimpleNamespace(name="Warehouse", id="conn_1"), + SimpleNamespace(name="Warehouse", id="conn_2"), + SimpleNamespace(name="Analytics", id="conn_3"), + ] + options = _connection_options(conns) + assert options == { + "Warehouse": "conn_1", + "Warehouse (conn_2)": "conn_2", + "Analytics": "conn_3", + } diff --git a/tests/test_sql_engine_registry.py b/tests/test_sql_engine_registry.py new file mode 100644 index 0000000..b836572 --- /dev/null +++ b/tests/test_sql_engine_registry.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import hotdata_marimo as hm +from hotdata_runtime import HotdataClient +from hotdata_marimo.sql_engine import HotdataMarimoEngine +from marimo._types.ids import VariableName + + +def test_register_hotdata_sql_engine_is_idempotent() -> None: + from marimo._sql.get_engines import SUPPORTED_ENGINES + + hm.unregister_hotdata_sql_engine() + assert SUPPORTED_ENGINES.count(HotdataMarimoEngine) == 0 + try: + hm.register_hotdata_sql_engine() + hm.register_hotdata_sql_engine() + assert SUPPORTED_ENGINES.count(HotdataMarimoEngine) == 1 + finally: + hm.unregister_hotdata_sql_engine() + + +def test_hotdata_engine_display_name_in_marimo_ui() -> None: + hm.register_hotdata_sql_engine() + try: + client = MagicMock(spec=HotdataClient) + client.connections.return_value.list_connections.return_value = ( + SimpleNamespace(connections=[]) + ) + engine = HotdataMarimoEngine(client, engine_name=VariableName("client")) + import marimo._sql.get_engines as ge + + conn = ge.engine_to_data_source_connection(VariableName("client"), engine) + assert conn.display_name == "Hotdata" + + import marimo._runtime.runner.hooks_post_execution as hpe + + conn_hpe = hpe.engine_to_data_source_connection( + VariableName("client"), engine + ) + assert conn_hpe.display_name == "Hotdata" + finally: + hm.unregister_hotdata_sql_engine() diff --git a/uv.lock b/uv.lock index d01e539..08e605f 100644 --- a/uv.lock +++ b/uv.lock @@ -222,7 +222,10 @@ requires-dist = [ ] [package.metadata.requires-dev] -dev = [{ name = "pytest", specifier = ">=8.0" }] +dev = [ + { name = "packaging", specifier = ">=23" }, + { name = "pytest", specifier = ">=8.0" }, +] [[package]] name = "idna"