Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions ggsql-python/python/ggsql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import json
from typing import Any, Union
from typing import Any, Protocol, Union, runtime_checkable

import altair
import narwhals as nw
from narwhals.typing import IntoFrame
import polars as pl

from ggsql._ggsql import (
DuckDBReader,
Expand All @@ -14,6 +15,10 @@
Spec,
validate,
execute,
ParseError,
ValidationError,
ReaderError,
WriterError,
)

__all__ = [
Expand All @@ -22,12 +27,18 @@
"VegaLiteWriter",
"Validated",
"Spec",
"Reader",
# Functions
"validate",
"execute",
"render_altair",
# Exceptions
"ParseError",
"ValidationError",
"ReaderError",
"WriterError",
]
__version__ = "0.1.0"
__version__ = "0.1.4"

# Type alias for any Altair chart type
AltairChart = Union[
Expand All @@ -41,6 +52,29 @@
]


@runtime_checkable
class Reader(Protocol):
"""Protocol for ggsql database readers.

Any object implementing these methods can be used as a reader with
``ggsql.execute()``. Native readers like ``DuckDBReader`` satisfy
this protocol automatically.

Required methods
----------------
execute_sql(sql: str) -> polars.DataFrame
Execute a SQL query and return results as a polars DataFrame.
register(name: str, df: polars.DataFrame, replace: bool = False) -> None
Register a DataFrame as a named table for SQL queries.
"""

def execute_sql(self, sql: str) -> pl.DataFrame: ...

def register(
self, name: str, df: pl.DataFrame, replace: bool = False
) -> None: ...


def _json_to_altair_chart(vegalite_json: str, **kwargs: Any) -> AltairChart:
"""Convert a Vega-Lite JSON string to the appropriate Altair chart type."""
spec = json.loads(vegalite_json)
Expand Down
84 changes: 65 additions & 19 deletions ggsql-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// See: https://github.com/PyO3/pyo3/issues/4327
#![allow(clippy::useless_conversion)]

use pyo3::create_exception;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList};
use std::io::Cursor;
Expand All @@ -12,6 +14,48 @@ use ggsql::validate::{validate as rust_validate, ValidationWarning};
use ggsql::writer::{VegaLiteWriter as RustVegaLiteWriter, Writer as RustWriter};
use ggsql::GgsqlError;

// ============================================================================
// Custom Exception Classes
// ============================================================================

// All subclass ValueError for backwards compatibility
create_exception!(
ggsql,
ParseError,
PyValueError,
"Raised on query syntax errors."
);
create_exception!(
ggsql,
ValidationError,
PyValueError,
"Raised on semantic validation errors."
);
create_exception!(
ggsql,
ReaderError,
PyValueError,
"Raised on data source errors."
);
create_exception!(
ggsql,
WriterError,
PyValueError,
"Raised on output generation errors."
);

/// Convert a GgsqlError to the appropriate typed Python exception.
fn ggsql_err_to_py(e: GgsqlError) -> PyErr {
let msg = e.to_string();
match e {
GgsqlError::ParseError(_) => PyErr::new::<ParseError, _>(msg),
GgsqlError::ValidationError(_) => PyErr::new::<ValidationError, _>(msg),
GgsqlError::ReaderError(_) => PyErr::new::<ReaderError, _>(msg),
GgsqlError::WriterError(_) => PyErr::new::<WriterError, _>(msg),
GgsqlError::InternalError(_) => PyErr::new::<PyValueError, _>(msg),
}
}

use polars::prelude::{DataFrame, IpcReader, IpcWriter, SerReader, SerWriter};

// ============================================================================
Expand Down Expand Up @@ -142,9 +186,13 @@ impl Reader for PyReaderBridge {
Python::attach(|py| {
let py_df =
polars_to_py(py, &df).map_err(|e| GgsqlError::ReaderError(e.to_string()))?;
let kwargs = PyDict::new(py);
kwargs
.set_item("replace", replace)
.map_err(|e| GgsqlError::ReaderError(e.to_string()))?;
self.obj
.bind(py)
.call_method1("register", (name, py_df, replace))
.call_method("register", (name, py_df), Some(&kwargs))
.map_err(|e| GgsqlError::ReaderError(format!("Reader.register() failed: {}", e)))?;
Ok(())
})
Expand Down Expand Up @@ -175,7 +223,7 @@ macro_rules! try_native_readers {
if let Ok(native) = $reader.downcast::<$native_type>() {
return native.borrow().inner.execute($query)
.map(|s| PySpec { inner: s })
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()));
.map_err(ggsql_err_to_py);
}
)*
}};
Expand Down Expand Up @@ -224,8 +272,8 @@ impl PyDuckDBReader {
/// If the connection string is invalid or the database cannot be opened.
#[new]
fn new(connection: &str) -> PyResult<Self> {
let inner = RustDuckDBReader::from_connection_string(connection)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
let inner =
RustDuckDBReader::from_connection_string(connection).map_err(ggsql_err_to_py)?;
Ok(Self { inner })
}

Expand Down Expand Up @@ -255,7 +303,7 @@ impl PyDuckDBReader {
let rust_df = py_to_polars(py, df)?;
self.inner
.register(name, rust_df, replace)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
.map_err(ggsql_err_to_py)
}

/// Unregister a previously registered table.
Expand All @@ -270,9 +318,7 @@ impl PyDuckDBReader {
/// ValueError
/// If the table wasn't registered via this reader or unregistration fails.
fn unregister(&self, name: &str) -> PyResult<()> {
self.inner
.unregister(name)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
self.inner.unregister(name).map_err(ggsql_err_to_py)
}

/// Execute a SQL query and return the result as a DataFrame.
Expand All @@ -292,10 +338,7 @@ impl PyDuckDBReader {
/// ValueError
/// If the SQL is invalid or execution fails.
fn execute_sql(&self, py: Python<'_>, sql: &str) -> PyResult<Py<PyAny>> {
let df = self
.inner
.execute_sql(sql)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
let df = self.inner.execute_sql(sql).map_err(ggsql_err_to_py)?;
polars_to_py(py, &df)
}

Expand Down Expand Up @@ -330,7 +373,7 @@ impl PyDuckDBReader {
self.inner
.execute(query)
.map(|s| PySpec { inner: s })
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
.map_err(ggsql_err_to_py)
}
}

Expand Down Expand Up @@ -391,9 +434,7 @@ impl PyVegaLiteWriter {
/// >>> writer = VegaLiteWriter()
/// >>> json_output = writer.render(spec)
fn render(&self, spec: &PySpec) -> PyResult<String> {
self.inner
.render(&spec.inner)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
self.inner.render(&spec.inner).map_err(ggsql_err_to_py)
}
}

Expand Down Expand Up @@ -657,8 +698,7 @@ impl PySpec {
/// If validation fails unexpectedly (not for syntax errors, which are captured).
#[pyfunction]
fn validate(query: &str) -> PyResult<PyValidated> {
let v = rust_validate(query)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
let v = rust_validate(query).map_err(ggsql_err_to_py)?;

Ok(PyValidated {
sql: v.sql().to_string(),
Expand Down Expand Up @@ -739,7 +779,7 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {
bridge
.execute(query)
.map(|s| PySpec { inner: s })
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))
.map_err(ggsql_err_to_py)
}

// ============================================================================
Expand All @@ -748,6 +788,12 @@ fn execute(query: &str, reader: &Bound<'_, PyAny>) -> PyResult<PySpec> {

#[pymodule]
fn _ggsql(m: &Bound<'_, PyModule>) -> PyResult<()> {
// Exceptions
m.add("ParseError", m.py().get_type::<ParseError>())?;
m.add("ValidationError", m.py().get_type::<ValidationError>())?;
m.add("ReaderError", m.py().get_type::<ReaderError>())?;
m.add("WriterError", m.py().get_type::<WriterError>())?;

// Classes
m.add_class::<PyDuckDBReader>()?;
m.add_class::<PyVegaLiteWriter>()?;
Expand Down
93 changes: 89 additions & 4 deletions ggsql-python/tests/test_ggsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def __init__(self):
def execute_sql(self, sql: str) -> pl.DataFrame:
return self.conn.execute(sql).pl()

def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None:
self.conn.register(name, df)

reader = RegisterReader()
Expand Down Expand Up @@ -453,7 +453,7 @@ def __init__(self):
def execute_sql(self, sql: str) -> pl.DataFrame:
return self.conn.execute(sql).pl()

def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None:
self.conn.register(name, df)

reader = DuckDBBackedReader()
Expand Down Expand Up @@ -484,7 +484,7 @@ def execute_sql(self, sql: str) -> pl.DataFrame:
self.execute_calls.append(sql)
return self.conn.execute(sql).pl()

def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None:
def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None:
self.conn.register(name, df)

reader = RecordingReader()
Expand Down Expand Up @@ -532,6 +532,92 @@ def unregister(self, name: str) -> None:
assert "point" in json_output


class TestExceptions:
"""Tests for typed exception classes."""

def test_parse_error_on_invalid_syntax(self):
"""Invalid syntax raises ParseError when executing."""
with pytest.raises(ggsql.ParseError):
reader = ggsql.DuckDBReader("duckdb://memory")
reader.execute("SELECT 1 AS x VISUALISE DRAW not_a_geom")

def test_parse_error_is_value_error(self):
"""ParseError is a subclass of ValueError for backwards compat."""
assert issubclass(ggsql.ParseError, ValueError)

def test_validation_error_on_missing_aesthetic(self):
"""Missing required aesthetic raises ValidationError."""
with pytest.raises(ggsql.ValidationError):
reader = ggsql.DuckDBReader("duckdb://memory")
reader.execute("SELECT 1 AS x VISUALISE DRAW point MAPPING x AS x")

def test_validation_error_is_value_error(self):
"""ValidationError is a subclass of ValueError for backwards compat."""
assert issubclass(ggsql.ValidationError, ValueError)

def test_reader_error_on_bad_sql(self):
"""Bad SQL raises ReaderError."""
with pytest.raises(ggsql.ReaderError):
reader = ggsql.DuckDBReader("duckdb://memory")
reader.execute(
"SELECT * FROM nonexistent_table VISUALISE DRAW point MAPPING x AS x, y AS y"
)

def test_reader_error_is_value_error(self):
"""ReaderError is a subclass of ValueError for backwards compat."""
assert issubclass(ggsql.ReaderError, ValueError)

def test_writer_error_is_value_error(self):
"""WriterError is a subclass of ValueError for backwards compat."""
assert issubclass(ggsql.WriterError, ValueError)

def test_all_exceptions_exported(self):
"""All exception classes are accessible from ggsql module."""
assert hasattr(ggsql, "ParseError")
assert hasattr(ggsql, "ValidationError")
assert hasattr(ggsql, "ReaderError")
assert hasattr(ggsql, "WriterError")


class TestReaderProtocol:
"""Tests for Reader protocol."""

def test_duckdb_reader_is_reader(self):
"""Native DuckDBReader satisfies the Reader protocol."""
reader = ggsql.DuckDBReader("duckdb://memory")
assert isinstance(reader, ggsql.Reader)

def test_custom_reader_is_reader(self):
"""Custom reader with correct methods satisfies the Reader protocol."""

class MyReader:
def execute_sql(self, sql: str) -> pl.DataFrame:
return pl.DataFrame({"x": [1]})

def register(
self, name: str, df: pl.DataFrame, replace: bool = False
) -> None:
pass

reader = MyReader()
assert isinstance(reader, ggsql.Reader)

def test_incomplete_reader_is_not_reader(self):
"""Object missing required methods is not a Reader."""

class NotAReader:
def execute_sql(self, sql: str) -> pl.DataFrame:
return pl.DataFrame({"x": [1]})
# Missing register()

obj = NotAReader()
assert not isinstance(obj, ggsql.Reader)

def test_reader_is_exported(self):
"""Reader is accessible from ggsql module."""
assert hasattr(ggsql, "Reader")


class TestVegaLiteWriterRenderChart:
"""Tests for VegaLiteWriter.render_chart() method."""

Expand Down Expand Up @@ -568,4 +654,3 @@ def test_render_chart_facet(self):
writer = ggsql.VegaLiteWriter()
chart = writer.render_chart(spec, validate=False)
assert isinstance(chart, altair.FacetChart)

Loading