diff --git a/README.md b/README.md index de96fda..5067003 100644 --- a/README.md +++ b/README.md @@ -103,3 +103,63 @@ users may find useful: Consider multi-session for potential cost savings, but be mindful of performance impacts from shared resources. You might need to adjust cluster size if slowdowns occur, which could affect overall cost. + +### Storing results to cloud storage + +Instead of receiving query results inline over the WebSocket connection, +you can have the server write them to cloud storage (S3) using the +`Store` class. This is useful for large result sets or when you need a +downloadable file. + +```python +from wherobots.db import connect, Store, StorageFormat +from wherobots.db.region import Region +from wherobots.db.runtime import Runtime + +with connect( + api_key='...', + runtime=Runtime.TINY, + region=Region.AWS_US_WEST_2) as conn: + curr = conn.cursor() + + # Store results as a single GeoJSON file with a presigned download URL + store = Store.for_download(format=StorageFormat.GEOJSON) + curr.execute("SELECT * FROM my_table", store=store) + results = curr.fetchall() +``` + +#### Store options + +You can pass format-specific Spark write options through the `options` +parameter. These correspond to the options available in Spark's +`DataFrameWriter` and are applied after the server's default options, +allowing you to override them. + +```python +# CSV without headers +store = Store.for_download( + format=StorageFormat.CSV, + options={"header": "false", "delimiter": "|"}, +) + +# GeoJSON preserving null fields +store = Store.for_download( + format=StorageFormat.GEOJSON, + options={"ignoreNullFields": "false"}, +) +``` + +You can also set a default `Store` at connection time, which will be +used for all queries executed through cursors created from that +connection unless overridden per-query: + +```python +with connect( + api_key='...', + runtime=Runtime.TINY, + region=Region.AWS_US_WEST_2, + store=Store.for_download(format=StorageFormat.PARQUET)) as conn: + curr = conn.cursor() + # All queries through this cursor will use the connection-level store + curr.execute("SELECT * FROM my_table") +``` diff --git a/tests/test_result_store.py b/tests/test_result_store.py new file mode 100644 index 0000000..2badf37 --- /dev/null +++ b/tests/test_result_store.py @@ -0,0 +1,170 @@ +"""Tests for result_store module: Store dataclass and StorageFormat enum.""" + +import json +import pytest + +from wherobots.db.result_store import Store, StorageFormat, DEFAULT_STORAGE_FORMAT + + +class TestStorageFormat: + def test_values(self): + assert StorageFormat.PARQUET.value == "parquet" + assert StorageFormat.CSV.value == "csv" + assert StorageFormat.GEOJSON.value == "geojson" + + def test_default_format(self): + assert DEFAULT_STORAGE_FORMAT == StorageFormat.PARQUET + + +class TestStore: + def test_default_construction(self): + store = Store() + assert store.format == StorageFormat.PARQUET + assert store.single is False + assert store.generate_presigned_url is False + assert store.options is None + + def test_with_format(self): + store = Store(format=StorageFormat.CSV) + assert store.format == StorageFormat.CSV + assert store.options is None + + def test_with_options(self): + store = Store( + format=StorageFormat.GEOJSON, + options={"ignoreNullFields": "false"}, + ) + assert store.options == {"ignoreNullFields": "false"} + + def test_with_multiple_options(self): + opts = {"header": "false", "delimiter": "|", "quote": '"'} + store = Store(format=StorageFormat.CSV, options=opts) + assert store.options == opts + + def test_empty_options_normalized_to_none(self): + store = Store(options={}) + assert store.options is None + + def test_none_options(self): + store = Store(options=None) + assert store.options is None + + def test_options_defensively_copied(self): + original = {"key": "value"} + store = Store(options=original) + # Mutating the original should not affect the store + original["key"] = "changed" + assert store.options == {"key": "value"} + + def test_frozen_dataclass(self): + store = Store() + with pytest.raises(AttributeError): + store.format = StorageFormat.CSV + + def test_presigned_url_requires_single(self): + with pytest.raises(ValueError, match="single=True"): + Store(generate_presigned_url=True, single=False) + + def test_presigned_url_with_single(self): + store = Store(single=True, generate_presigned_url=True) + assert store.single is True + assert store.generate_presigned_url is True + + +class TestStoreForDownload: + def test_default(self): + store = Store.for_download() + assert store.format == StorageFormat.PARQUET + assert store.single is True + assert store.generate_presigned_url is True + assert store.options is None + + def test_with_format(self): + store = Store.for_download(format=StorageFormat.CSV) + assert store.format == StorageFormat.CSV + assert store.single is True + assert store.generate_presigned_url is True + + def test_with_options(self): + store = Store.for_download( + format=StorageFormat.GEOJSON, + options={"ignoreNullFields": "false"}, + ) + assert store.format == StorageFormat.GEOJSON + assert store.options == {"ignoreNullFields": "false"} + + +class TestStoreToDict: + def test_without_options(self): + store = Store(format=StorageFormat.PARQUET, single=True) + d = store.to_dict() + assert d == { + "format": "parquet", + "single": True, + "generate_presigned_url": False, + } + assert "options" not in d + + def test_with_options(self): + store = Store( + format=StorageFormat.GEOJSON, + single=True, + generate_presigned_url=True, + options={"ignoreNullFields": "false"}, + ) + d = store.to_dict() + assert d == { + "format": "geojson", + "single": True, + "generate_presigned_url": True, + "options": {"ignoreNullFields": "false"}, + } + + def test_serializable_to_json(self): + store = Store.for_download( + format=StorageFormat.CSV, + options={"header": "false"}, + ) + serialized = json.dumps(store.to_dict()) + deserialized = json.loads(serialized) + assert deserialized["format"] == "csv" + assert deserialized["single"] is True + assert deserialized["generate_presigned_url"] is True + assert deserialized["options"] == {"header": "false"} + + def test_to_dict_returns_copy(self): + """Mutating the returned dict should not affect the Store.""" + store = Store(options={"key": "value"}) + d = store.to_dict() + d["options"]["key"] = "changed" + assert store.options == {"key": "value"} + + def test_full_execute_sql_request_shape(self): + """Verify the dict integrates correctly into an execute_sql request.""" + store = Store.for_download( + format=StorageFormat.GEOJSON, + options={"ignoreNullFields": "false"}, + ) + request = { + "kind": "execute_sql", + "execution_id": "test-id", + "statement": "SELECT 1", + "store": store.to_dict(), + } + serialized = json.dumps(request) + parsed = json.loads(serialized) + assert parsed["store"]["format"] == "geojson" + assert parsed["store"]["single"] is True + assert parsed["store"]["generate_presigned_url"] is True + assert parsed["store"]["options"] == {"ignoreNullFields": "false"} + + def test_request_without_store(self): + """Without a store, the request should not have a store key.""" + request = { + "kind": "execute_sql", + "execution_id": "test-id", + "statement": "SELECT 1", + } + serialized = json.dumps(request) + parsed = json.loads(serialized) + assert "store" not in parsed diff --git a/wherobots/db/__init__.py b/wherobots/db/__init__.py index 3e9a96e..127a415 100644 --- a/wherobots/db/__init__.py +++ b/wherobots/db/__init__.py @@ -11,6 +11,7 @@ NotSupportedError, ) from .region import Region +from .result_store import Store, StorageFormat from .runtime import Runtime __all__ = [ @@ -27,4 +28,6 @@ "NotSupportedError", "Region", "Runtime", + "Store", + "StorageFormat", ] diff --git a/wherobots/db/connection.py b/wherobots/db/connection.py index 47bbf61..06f478a 100644 --- a/wherobots/db/connection.py +++ b/wherobots/db/connection.py @@ -24,6 +24,7 @@ ) from wherobots.db.cursor import Cursor from wherobots.db.errors import NotSupportedError, OperationalError +from wherobots.db.result_store import Store @dataclass @@ -56,12 +57,14 @@ def __init__( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + store: Union[Store, None] = None, ): self.__ws = ws self.__read_timeout = read_timeout self.__results_format = results_format self.__data_compression = data_compression self.__geometry_representation = geometry_representation + self.__store = store self.__queries: dict[str, Query] = {} self.__thread = threading.Thread( @@ -85,7 +88,7 @@ def rollback(self) -> None: raise NotSupportedError def cursor(self) -> Cursor: - return Cursor(self.__execute_sql, self.__cancel_query) + return Cursor(self.__execute_sql, self.__cancel_query, self.__store) def __main_loop(self) -> None: """Main background loop listening for messages from the SQL session.""" @@ -200,7 +203,12 @@ def __recv(self) -> Dict[str, Any]: raise ValueError("Unexpected frame type received") return message - def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str: + def __execute_sql( + self, + sql: str, + handler: Callable[[Any], None], + store: Union[Store, None] = None, + ) -> str: """Triggers the execution of the given SQL query.""" execution_id = str(uuid.uuid4()) request = { @@ -209,6 +217,9 @@ def __execute_sql(self, sql: str, handler: Callable[[Any], None]) -> str: "statement": sql, } + if store is not None: + request["store"] = store.to_dict() + self.__queries[execution_id] = Query( sql=sql, execution_id=execution_id, diff --git a/wherobots/db/cursor.py b/wherobots/db/cursor.py index b12187e..592a436 100644 --- a/wherobots/db/cursor.py +++ b/wherobots/db/cursor.py @@ -1,7 +1,8 @@ import queue -from typing import Any, Optional, List, Tuple, Dict +from typing import Any, Optional, List, Tuple, Dict, Union from .errors import DatabaseError, ProgrammingError +from .result_store import Store _TYPE_MAP = { "object": "STRING", @@ -15,9 +16,15 @@ class Cursor: - def __init__(self, exec_fn, cancel_fn) -> None: + def __init__( + self, + exec_fn, + cancel_fn, + default_store: Union[Store, None] = None, + ) -> None: self.__exec_fn = exec_fn self.__cancel_fn = cancel_fn + self.__default_store = default_store self.__queue: queue.Queue = queue.Queue() self.__results: Optional[list[Any]] = None @@ -71,7 +78,12 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]: return self.__results - def execute(self, operation: str, parameters: Dict[str, Any] = None) -> None: + def execute( + self, + operation: str, + parameters: Dict[str, Any] = None, + store: Union[Store, None] = None, + ) -> None: if self.__current_execution_id: self.__cancel_fn(self.__current_execution_id) @@ -83,7 +95,10 @@ def execute(self, operation: str, parameters: Dict[str, Any] = None) -> None: sql = ( operation.replace("{", "{{").replace("}", "}}").format(**(parameters or {})) ) - self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result) + effective_store = store if store is not None else self.__default_store + self.__current_execution_id = self.__exec_fn( + sql, self.__on_execution_result, effective_store + ) def executemany( self, operation: str, seq_of_parameters: List[Dict[str, Any]] diff --git a/wherobots/db/driver.py b/wherobots/db/driver.py index 81573b1..03a598a 100644 --- a/wherobots/db/driver.py +++ b/wherobots/db/driver.py @@ -37,6 +37,7 @@ OperationalError, ) from .region import Region +from .result_store import Store from .runtime import Runtime apilevel = "2.0" @@ -69,6 +70,7 @@ def connect( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + store: Union[Store, None] = None, ) -> Connection: if not token and not api_key: raise ValueError("At least one of `token` or `api_key` is required") @@ -151,6 +153,7 @@ def get_session_uri() -> str: results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + store=store, ) @@ -171,6 +174,7 @@ def connect_direct( results_format: Union[ResultsFormat, None] = None, data_compression: Union[DataCompression, None] = None, geometry_representation: Union[GeometryRepresentation, None] = None, + store: Union[Store, None] = None, ) -> Connection: uri_with_protocol = f"{uri}/{protocol}" @@ -193,4 +197,5 @@ def connect_direct( results_format=results_format, data_compression=data_compression, geometry_representation=geometry_representation, + store=store, ) diff --git a/wherobots/db/result_store.py b/wherobots/db/result_store.py new file mode 100644 index 0000000..b2ef95f --- /dev/null +++ b/wherobots/db/result_store.py @@ -0,0 +1,86 @@ +"""Result storage configuration for Wherobots DB queries. + +Provides the :class:`StorageFormat` enum and :class:`Store` dataclass to configure +how query results are stored in cloud storage (e.g. S3) instead of being returned +directly over the WebSocket connection. +""" + +from dataclasses import dataclass, field +from enum import auto +from typing import Dict, Optional + +from strenum import LowercaseStrEnum + + +class StorageFormat(LowercaseStrEnum): + """Supported formats for storing query results.""" + + PARQUET = auto() + CSV = auto() + GEOJSON = auto() + + +DEFAULT_STORAGE_FORMAT = StorageFormat.PARQUET + + +@dataclass(frozen=True) +class Store: + """Configuration for storing query results to cloud storage. + + When a :class:`Store` is provided on a cursor's ``execute()`` call, the query results + are written to cloud storage in the specified format rather than being returned inline + over the WebSocket connection. + + :param format: The storage format (parquet, csv, geojson). Defaults to parquet. + :param single: Whether to coalesce results into a single file. + :param generate_presigned_url: Whether to generate a presigned download URL. + Only valid when ``single=True``. + :param options: Optional format-specific Spark DataFrameWriter options, + e.g. ``{"ignoreNullFields": "false"}`` for GeoJSON or ``{"header": "false"}`` for CSV. + These are applied after the server's default options and can override them. + """ + + format: StorageFormat = DEFAULT_STORAGE_FORMAT + single: bool = False + generate_presigned_url: bool = False + options: Optional[Dict[str, str]] = field(default=None) + + def __post_init__(self): + if self.generate_presigned_url and not self.single: + raise ValueError("Can only generate a presigned URL when single=True") + # Normalize empty options to None + if self.options is not None and len(self.options) == 0: + object.__setattr__(self, "options", None) + # Defensive copy: make options immutable + if self.options is not None: + object.__setattr__(self, "options", dict(self.options)) + + def to_dict(self) -> Dict: + """Serialize to a dict suitable for the WebSocket protocol.""" + d = { + "format": self.format.value, + "single": self.single, + "generate_presigned_url": self.generate_presigned_url, + } + if self.options is not None: + d["options"] = dict(self.options) + return d + + @staticmethod + def for_download( + format: StorageFormat = DEFAULT_STORAGE_FORMAT, + options: Optional[Dict[str, str]] = None, + ) -> "Store": + """Create a Store configured for single-file download with a presigned URL. + + This is the most common configuration for programmatic result retrieval. + + :param format: The storage format. Defaults to parquet. + :param options: Optional format-specific write options. + """ + return Store( + format=format, + single=True, + generate_presigned_url=True, + options=options, + )