diff --git a/ggsql-python/python/ggsql/__init__.py b/ggsql-python/python/ggsql/__init__.py index f9356e72..42b2338e 100644 --- a/ggsql-python/python/ggsql/__init__.py +++ b/ggsql-python/python/ggsql/__init__.py @@ -1,11 +1,12 @@ from __future__ import annotations import json -from typing import Any, Union +from typing import Any, Protocol, Union, runtime_checkable import altair import narwhals as nw from narwhals.typing import IntoFrame +import polars as pl from ggsql._ggsql import ( DuckDBReader, @@ -14,6 +15,10 @@ Spec, validate, execute, + ParseError, + ValidationError, + ReaderError, + WriterError, ) __all__ = [ @@ -22,12 +27,18 @@ "VegaLiteWriter", "Validated", "Spec", + "Reader", # Functions "validate", "execute", "render_altair", + # Exceptions + "ParseError", + "ValidationError", + "ReaderError", + "WriterError", ] -__version__ = "0.1.0" +__version__ = "0.1.4" # Type alias for any Altair chart type AltairChart = Union[ @@ -41,6 +52,29 @@ ] +@runtime_checkable +class Reader(Protocol): + """Protocol for ggsql database readers. + + Any object implementing these methods can be used as a reader with + ``ggsql.execute()``. Native readers like ``DuckDBReader`` satisfy + this protocol automatically. + + Required methods + ---------------- + execute_sql(sql: str) -> polars.DataFrame + Execute a SQL query and return results as a polars DataFrame. + register(name: str, df: polars.DataFrame, replace: bool = False) -> None + Register a DataFrame as a named table for SQL queries. + """ + + def execute_sql(self, sql: str) -> pl.DataFrame: ... + + def register( + self, name: str, df: pl.DataFrame, replace: bool = False + ) -> None: ... + + def _json_to_altair_chart(vegalite_json: str, **kwargs: Any) -> AltairChart: """Convert a Vega-Lite JSON string to the appropriate Altair chart type.""" spec = json.loads(vegalite_json) diff --git a/ggsql-python/src/lib.rs b/ggsql-python/src/lib.rs index 27d26b68..91710c21 100644 --- a/ggsql-python/src/lib.rs +++ b/ggsql-python/src/lib.rs @@ -2,6 +2,8 @@ // See: https://github.com/PyO3/pyo3/issues/4327 #![allow(clippy::useless_conversion)] +use pyo3::create_exception; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyDict, PyList}; use std::io::Cursor; @@ -12,6 +14,48 @@ use ggsql::validate::{validate as rust_validate, ValidationWarning}; use ggsql::writer::{VegaLiteWriter as RustVegaLiteWriter, Writer as RustWriter}; use ggsql::GgsqlError; +// ============================================================================ +// Custom Exception Classes +// ============================================================================ + +// All subclass ValueError for backwards compatibility +create_exception!( + ggsql, + ParseError, + PyValueError, + "Raised on query syntax errors." +); +create_exception!( + ggsql, + ValidationError, + PyValueError, + "Raised on semantic validation errors." +); +create_exception!( + ggsql, + ReaderError, + PyValueError, + "Raised on data source errors." +); +create_exception!( + ggsql, + WriterError, + PyValueError, + "Raised on output generation errors." +); + +/// Convert a GgsqlError to the appropriate typed Python exception. +fn ggsql_err_to_py(e: GgsqlError) -> PyErr { + let msg = e.to_string(); + match e { + GgsqlError::ParseError(_) => PyErr::new::(msg), + GgsqlError::ValidationError(_) => PyErr::new::(msg), + GgsqlError::ReaderError(_) => PyErr::new::(msg), + GgsqlError::WriterError(_) => PyErr::new::(msg), + GgsqlError::InternalError(_) => PyErr::new::(msg), + } +} + use polars::prelude::{DataFrame, IpcReader, IpcWriter, SerReader, SerWriter}; // ============================================================================ @@ -142,9 +186,13 @@ impl Reader for PyReaderBridge { Python::attach(|py| { let py_df = polars_to_py(py, &df).map_err(|e| GgsqlError::ReaderError(e.to_string()))?; + let kwargs = PyDict::new(py); + kwargs + .set_item("replace", replace) + .map_err(|e| GgsqlError::ReaderError(e.to_string()))?; self.obj .bind(py) - .call_method1("register", (name, py_df, replace)) + .call_method("register", (name, py_df), Some(&kwargs)) .map_err(|e| GgsqlError::ReaderError(format!("Reader.register() failed: {}", e)))?; Ok(()) }) @@ -175,7 +223,7 @@ macro_rules! try_native_readers { if let Ok(native) = $reader.downcast::<$native_type>() { return native.borrow().inner.execute($query) .map(|s| PySpec { inner: s }) - .map_err(|e| PyErr::new::(e.to_string())); + .map_err(ggsql_err_to_py); } )* }}; @@ -224,8 +272,8 @@ impl PyDuckDBReader { /// If the connection string is invalid or the database cannot be opened. #[new] fn new(connection: &str) -> PyResult { - let inner = RustDuckDBReader::from_connection_string(connection) - .map_err(|e| PyErr::new::(e.to_string()))?; + let inner = + RustDuckDBReader::from_connection_string(connection).map_err(ggsql_err_to_py)?; Ok(Self { inner }) } @@ -255,7 +303,7 @@ impl PyDuckDBReader { let rust_df = py_to_polars(py, df)?; self.inner .register(name, rust_df, replace) - .map_err(|e| PyErr::new::(e.to_string())) + .map_err(ggsql_err_to_py) } /// Unregister a previously registered table. @@ -270,9 +318,7 @@ impl PyDuckDBReader { /// ValueError /// If the table wasn't registered via this reader or unregistration fails. fn unregister(&self, name: &str) -> PyResult<()> { - self.inner - .unregister(name) - .map_err(|e| PyErr::new::(e.to_string())) + self.inner.unregister(name).map_err(ggsql_err_to_py) } /// Execute a SQL query and return the result as a DataFrame. @@ -292,10 +338,7 @@ impl PyDuckDBReader { /// ValueError /// If the SQL is invalid or execution fails. fn execute_sql(&self, py: Python<'_>, sql: &str) -> PyResult> { - let df = self - .inner - .execute_sql(sql) - .map_err(|e| PyErr::new::(e.to_string()))?; + let df = self.inner.execute_sql(sql).map_err(ggsql_err_to_py)?; polars_to_py(py, &df) } @@ -330,7 +373,7 @@ impl PyDuckDBReader { self.inner .execute(query) .map(|s| PySpec { inner: s }) - .map_err(|e| PyErr::new::(e.to_string())) + .map_err(ggsql_err_to_py) } } @@ -391,9 +434,7 @@ impl PyVegaLiteWriter { /// >>> writer = VegaLiteWriter() /// >>> json_output = writer.render(spec) fn render(&self, spec: &PySpec) -> PyResult { - self.inner - .render(&spec.inner) - .map_err(|e| PyErr::new::(e.to_string())) + self.inner.render(&spec.inner).map_err(ggsql_err_to_py) } } @@ -657,8 +698,7 @@ impl PySpec { /// If validation fails unexpectedly (not for syntax errors, which are captured). #[pyfunction] fn validate(query: &str) -> PyResult { - let v = rust_validate(query) - .map_err(|e| PyErr::new::(e.to_string()))?; + let v = rust_validate(query).map_err(ggsql_err_to_py)?; Ok(PyValidated { sql: v.sql().to_string(), @@ -739,7 +779,7 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult { bridge .execute(query) .map(|s| PySpec { inner: s }) - .map_err(|e| PyErr::new::(e.to_string())) + .map_err(ggsql_err_to_py) } // ============================================================================ @@ -748,6 +788,12 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult { #[pymodule] fn _ggsql(m: &Bound<'_, PyModule>) -> PyResult<()> { + // Exceptions + m.add("ParseError", m.py().get_type::())?; + m.add("ValidationError", m.py().get_type::())?; + m.add("ReaderError", m.py().get_type::())?; + m.add("WriterError", m.py().get_type::())?; + // Classes m.add_class::()?; m.add_class::()?; diff --git a/ggsql-python/tests/test_ggsql.py b/ggsql-python/tests/test_ggsql.py index fbe4b131..962a628d 100644 --- a/ggsql-python/tests/test_ggsql.py +++ b/ggsql-python/tests/test_ggsql.py @@ -402,7 +402,7 @@ def __init__(self): def execute_sql(self, sql: str) -> pl.DataFrame: return self.conn.execute(sql).pl() - def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: self.conn.register(name, df) reader = RegisterReader() @@ -453,7 +453,7 @@ def __init__(self): def execute_sql(self, sql: str) -> pl.DataFrame: return self.conn.execute(sql).pl() - def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: self.conn.register(name, df) reader = DuckDBBackedReader() @@ -484,7 +484,7 @@ def execute_sql(self, sql: str) -> pl.DataFrame: self.execute_calls.append(sql) return self.conn.execute(sql).pl() - def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: self.conn.register(name, df) reader = RecordingReader() @@ -532,6 +532,92 @@ def unregister(self, name: str) -> None: assert "point" in json_output +class TestExceptions: + """Tests for typed exception classes.""" + + def test_parse_error_on_invalid_syntax(self): + """Invalid syntax raises ParseError when executing.""" + with pytest.raises(ggsql.ParseError): + reader = ggsql.DuckDBReader("duckdb://memory") + reader.execute("SELECT 1 AS x VISUALISE DRAW not_a_geom") + + def test_parse_error_is_value_error(self): + """ParseError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.ParseError, ValueError) + + def test_validation_error_on_missing_aesthetic(self): + """Missing required aesthetic raises ValidationError.""" + with pytest.raises(ggsql.ValidationError): + reader = ggsql.DuckDBReader("duckdb://memory") + reader.execute("SELECT 1 AS x VISUALISE DRAW point MAPPING x AS x") + + def test_validation_error_is_value_error(self): + """ValidationError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.ValidationError, ValueError) + + def test_reader_error_on_bad_sql(self): + """Bad SQL raises ReaderError.""" + with pytest.raises(ggsql.ReaderError): + reader = ggsql.DuckDBReader("duckdb://memory") + reader.execute( + "SELECT * FROM nonexistent_table VISUALISE DRAW point MAPPING x AS x, y AS y" + ) + + def test_reader_error_is_value_error(self): + """ReaderError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.ReaderError, ValueError) + + def test_writer_error_is_value_error(self): + """WriterError is a subclass of ValueError for backwards compat.""" + assert issubclass(ggsql.WriterError, ValueError) + + def test_all_exceptions_exported(self): + """All exception classes are accessible from ggsql module.""" + assert hasattr(ggsql, "ParseError") + assert hasattr(ggsql, "ValidationError") + assert hasattr(ggsql, "ReaderError") + assert hasattr(ggsql, "WriterError") + + +class TestReaderProtocol: + """Tests for Reader protocol.""" + + def test_duckdb_reader_is_reader(self): + """Native DuckDBReader satisfies the Reader protocol.""" + reader = ggsql.DuckDBReader("duckdb://memory") + assert isinstance(reader, ggsql.Reader) + + def test_custom_reader_is_reader(self): + """Custom reader with correct methods satisfies the Reader protocol.""" + + class MyReader: + def execute_sql(self, sql: str) -> pl.DataFrame: + return pl.DataFrame({"x": [1]}) + + def register( + self, name: str, df: pl.DataFrame, replace: bool = False + ) -> None: + pass + + reader = MyReader() + assert isinstance(reader, ggsql.Reader) + + def test_incomplete_reader_is_not_reader(self): + """Object missing required methods is not a Reader.""" + + class NotAReader: + def execute_sql(self, sql: str) -> pl.DataFrame: + return pl.DataFrame({"x": [1]}) + # Missing register() + + obj = NotAReader() + assert not isinstance(obj, ggsql.Reader) + + def test_reader_is_exported(self): + """Reader is accessible from ggsql module.""" + assert hasattr(ggsql, "Reader") + + class TestVegaLiteWriterRenderChart: """Tests for VegaLiteWriter.render_chart() method.""" @@ -568,4 +654,3 @@ def test_render_chart_facet(self): writer = ggsql.VegaLiteWriter() chart = writer.render_chart(spec, validate=False) assert isinstance(chart, altair.FacetChart) -