diff --git a/README.md b/README.md index 3efcecc..4fe0387 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,38 @@ The `Store` class supports the following options: Use `Store.for_download()` as a convenient shorthand for storing results as a single Parquet file with a presigned URL. +### Execution progress + +You can monitor the progress of running queries by registering a +progress handler on the connection. + +```python +from wherobots.db import connect, ProgressInfo +from wherobots.db.region import Region +from wherobots.db.runtime import Runtime + +def on_progress(info: ProgressInfo) -> None: + print(f"{info.tasks_completed}/{info.tasks_total} tasks " + f"({info.tasks_active} active)") + +with connect( + api_key='...', + runtime=Runtime.TINY, + region=Region.AWS_US_WEST_2) as conn: + conn.set_progress_handler(on_progress) + curr = conn.cursor() + curr.execute("SELECT ...") + results = curr.fetchall() +``` + +The handler receives a `ProgressInfo` object with `execution_id`, +`tasks_total`, `tasks_completed`, and `tasks_active` fields. Pass +`None` to `set_progress_handler()` to disable progress reporting. + +Progress events are best-effort and may not be available for all query +types or server versions. The handler is simply not invoked when no +progress information is available. + ### Runtime and region selection You can chose the Wherobots runtime you want to use using the `runtime` diff --git a/pyproject.toml b/pyproject.toml index db20516..720fe39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "wherobots-python-dbapi" -version = "0.23.2" +version = "0.24.0" description = "Python DB-API driver for Wherobots DB" authors = [{ name = "Maxime Petazzoni", email = "max@wherobots.com" }] requires-python = ">=3.10, <4" diff --git a/tests/smoke.py b/tests/smoke.py index ab65639..efeb4ea 100644 --- a/tests/smoke.py +++ b/tests/smoke.py @@ -10,9 +10,9 @@ from rich.console import Console from rich.table import Table -from wherobots.db import connect, connect_direct, errors -from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE +from wherobots.db import connect, connect_direct, errors, ProgressInfo from wherobots.db.connection import Connection +from wherobots.db.constants import DEFAULT_ENDPOINT, DEFAULT_SESSION_TYPE from wherobots.db.region import Region from wherobots.db.runtime import Runtime from wherobots.db.session_type import SessionType @@ -54,6 +54,11 @@ parser.add_argument( "--wide", help="Enable wide output", action="store_const", const=80, default=30 ) + parser.add_argument( + "--progress", + help="Enable execution progress reporting", + action="store_true", + ) parser.add_argument("sql", nargs="+", help="SQL query to execute") args = parser.parse_args() @@ -134,6 +139,26 @@ def execute(conn: Connection, sql: str) -> pandas.DataFrame | StoreResult: try: with conn_func() as conn: + if args.progress: + console = Console(stderr=True) + + def _on_progress(info: ProgressInfo) -> None: + pct = ( + f"{info.tasks_completed / info.tasks_total * 100:.0f}%" + if info.tasks_total + else "?" + ) + console.print( + f" [dim]\\[progress][/dim] " + f"[bold]{pct}[/bold] " + f"{info.tasks_completed}/{info.tasks_total} tasks " + f"[dim]({info.tasks_active} active)[/dim] " + f"[dim]{info.execution_id[:8]}[/dim]", + highlight=False, + ) + + conn.set_progress_handler(_on_progress) + with concurrent.futures.ThreadPoolExecutor() as pool: futures = [pool.submit(execute, conn, s) for s in args.sql] for future in concurrent.futures.as_completed(futures): diff --git a/wherobots/db/__init__.py b/wherobots/db/__init__.py index 5844678..667daf5 100644 --- a/wherobots/db/__init__.py +++ b/wherobots/db/__init__.py @@ -10,7 +10,7 @@ ProgrammingError, NotSupportedError, ) -from .models import Store, StoreResult +from .models import ProgressInfo, Store, StoreResult from .region import Region from .runtime import Runtime from .types import StorageFormat @@ -18,6 +18,7 @@ __all__ = [ "Connection", "Cursor", + "ProgressInfo", "connect", "connect_direct", "Error", diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index 04044a8..862abe4 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -16,7 +16,7 @@ from .constants import DEFAULT_READ_TIMEOUT_SECONDS from .cursor import Cursor from .errors import NotSupportedError, OperationalError -from .models import ExecutionResult, Store, StoreResult +from .models import ExecutionResult, ProgressInfo, Store, StoreResult from .types import ( RequestKind, EventKind, @@ -27,6 +27,10 @@ ) +ProgressHandler = Callable[[ProgressInfo], None] +"""A callable invoked with a :class:`ProgressInfo` on every progress event.""" + + @dataclass class Query: sql: str @@ -64,6 +68,7 @@ def __init__( self.__results_format = results_format self.__data_compression = data_compression self.__geometry_representation = geometry_representation + self.__progress_handler: ProgressHandler | None = None self.__queries: dict[str, Query] = {} self.__thread = threading.Thread( @@ -89,6 +94,21 @@ def rollback(self) -> None: def cursor(self) -> Cursor: return Cursor(self.__execute_sql, self.__cancel_query) + def set_progress_handler(self, handler: ProgressHandler | None) -> None: + """Register a callback invoked for execution progress events. + + When a handler is set, every ``execute_sql`` request automatically + includes ``enable_progress_events: true`` so the SQL session streams + progress updates for running queries. + + Pass ``None`` to disable progress reporting. + + This follows the `sqlite3 Connection.set_progress_handler() + `_ + pattern (PEP 249 vendor extension). + """ + self.__progress_handler = handler + def __main_loop(self) -> None: """Main background loop listening for messages from the SQL session.""" logging.info("Starting background connection handling loop...") @@ -116,6 +136,25 @@ def __listen(self) -> None: # Invalid event. return + # Progress events are independent of the query state machine and don't + # require a tracked query — the handler is connection-level. + if kind == EventKind.EXECUTION_PROGRESS: + handler = self.__progress_handler + if handler is None: + return + try: + handler( + ProgressInfo( + execution_id=execution_id, + tasks_total=message.get("tasks_total", 0), + tasks_completed=message.get("tasks_completed", 0), + tasks_active=message.get("tasks_active", 0), + ) + ) + except Exception: + logging.exception("Progress handler raised an exception") + return + query = self.__queries.get(execution_id) if not query: logging.warning( @@ -236,6 +275,9 @@ def __execute_sql( "statement": sql, } + if self.__progress_handler is not None: + request["enable_progress_events"] = True + if store: request["store"] = { "format": store.format.value, diff --git a/wherobots/db/models.py b/wherobots/db/models.py index 763b130..3a0a939 100644 --- a/wherobots/db/models.py +++ b/wherobots/db/models.py @@ -78,3 +78,16 @@ class ExecutionResult: results: pandas.DataFrame | None = None error: Exception | None = None store_result: StoreResult | None = None + + +@dataclass(frozen=True) +class ProgressInfo: + """Progress information for a running query. + + Mirrors the ``execution_progress`` event sent by the SQL session. + """ + + execution_id: str + tasks_total: int + tasks_completed: int + tasks_active: int diff --git a/wherobots/db/types.py b/wherobots/db/types.py index d0e021d..4c0745e 100644 --- a/wherobots/db/types.py +++ b/wherobots/db/types.py @@ -45,6 +45,7 @@ class EventKind(LowercaseStrEnum): STATE_UPDATED = auto() EXECUTION_RESULT = auto() ERROR = auto() + EXECUTION_PROGRESS = auto() class ResultsFormat(LowercaseStrEnum):