From bcc8566d1b789d50ff23c023ca385f30d3ed94d0 Mon Sep 17 00:00:00 2001 From: renato2099 Date: Sun, 1 Jun 2025 22:06:38 +0200 Subject: [PATCH 1/3] Exposing FFI to python Exposing FFI to python Workin progress on python catalog Flushing out schema and catalog providers Adding implementation of python based catalog and schema providers Small updates after rebase --- Cargo.lock | 19 + Cargo.toml | 2 + examples/datafusion-ffi-example/Cargo.lock | 1 + examples/datafusion-ffi-example/Cargo.toml | 1 + .../python/tests/_test_catalog_provider.py | 60 +++ .../src/catalog_provider.rs | 179 +++++++ examples/datafusion-ffi-example/src/lib.rs | 3 + python/datafusion/__init__.py | 3 +- python/datafusion/catalog.py | 86 +++- python/datafusion/context.py | 15 + python/tests/test_catalog.py | 102 +++- src/catalog.rs | 437 ++++++++++++++++-- src/context.rs | 52 ++- src/functions.rs | 2 +- src/lib.rs | 10 +- 15 files changed, 899 insertions(+), 73 deletions(-) create mode 100644 examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py create mode 100644 examples/datafusion-ffi-example/src/catalog_provider.rs diff --git a/Cargo.lock b/Cargo.lock index 112167cb4..a3e9336cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -165,6 +165,12 @@ dependencies = [ "zstd", ] +[[package]] +name = "arc-swap" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" + [[package]] name = "arrayref" version = "0.3.9" @@ -1503,6 +1509,7 @@ dependencies = [ "datafusion-proto", "datafusion-substrait", "futures", + "log", "mimalloc", "object_store", "prost", @@ -1510,6 +1517,7 @@ dependencies = [ "pyo3", "pyo3-async-runtimes", "pyo3-build-config", + "pyo3-log", "tokio", "url", "uuid", @@ -2953,6 +2961,17 @@ dependencies = [ "pyo3-build-config", ] +[[package]] +name = "pyo3-log" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264" +dependencies = [ + "arc-swap", + "log", + "pyo3", +] + [[package]] name = "pyo3-macros" version = "0.24.2" diff --git a/Cargo.toml b/Cargo.toml index 4135e64e2..1f7895a50 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,7 @@ substrait = ["dep:datafusion-substrait"] tokio = { version = "1.45", features = ["macros", "rt", "rt-multi-thread", "sync"] } pyo3 = { version = "0.24", features = ["extension-module", "abi3", "abi3-py39"] } pyo3-async-runtimes = { version = "0.24", features = ["tokio-runtime"]} +pyo3-log = "0.12.4" arrow = { version = "55.1.0", features = ["pyarrow"] } datafusion = { version = "48.0.0", features = ["avro", "unicode_expressions"] } datafusion-substrait = { version = "48.0.0", optional = true } @@ -49,6 +50,7 @@ async-trait = "0.1.88" futures = "0.3" object_store = { version = "0.12.1", features = ["aws", "gcp", "azure", "http"] } url = "2" +log = "0.4.27" [build-dependencies] prost-types = "0.13.1" # keep in line with `datafusion-substrait` diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock index 075ebd5a1..e5a1ca8d1 100644 --- a/examples/datafusion-ffi-example/Cargo.lock +++ b/examples/datafusion-ffi-example/Cargo.lock @@ -1448,6 +1448,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-schema", + "async-trait", "datafusion", "datafusion-ffi", "pyo3", diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 0e17567b9..319163554 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } arrow-schema = { version = "55.0.0" } +async-trait = "0.1.88" [build-dependencies] pyo3-build-config = "0.23" diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py new file mode 100644 index 000000000..72aadf64c --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext +from datafusion_ffi_example import MyCatalogProvider + + +def test_catalog_provider(): + ctx = SessionContext() + + my_catalog_name = "my_catalog" + expected_schema_name = "my_schema" + expected_table_name = "my_table" + expected_table_columns = ["units", "price"] + + catalog_provider = MyCatalogProvider() + ctx.register_catalog_provider(my_catalog_name, catalog_provider) + my_catalog = ctx.catalog(my_catalog_name) + + my_catalog_schemas = my_catalog.names() + assert expected_schema_name in my_catalog_schemas + my_database = my_catalog.database(expected_schema_name) + assert expected_table_name in my_database.names() + my_table = my_database.table(expected_table_name) + assert expected_table_columns == my_table.schema.names + + result = ctx.table( + f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}" + ).collect() + assert len(result) == 2 + + col0_result = [r.column(0) for r in result] + col1_result = [r.column(1) for r in result] + expected_col0 = [ + pa.array([10, 20, 30], type=pa.int32()), + pa.array([5, 7], type=pa.int32()), + ] + expected_col1 = [ + pa.array([1, 2, 5], type=pa.float64()), + pa.array([1.5, 2.5], type=pa.float64()), + ] + assert col0_result == expected_col0 + assert col1_result == expected_col1 diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs new file mode 100644 index 000000000..54e61cf3e --- /dev/null +++ b/examples/datafusion-ffi-example/src/catalog_provider.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::{any::Any, fmt::Debug, sync::Arc}; + +use arrow::datatypes::Schema; +use async_trait::async_trait; +use datafusion::{ + catalog::{ + CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider, + }, + common::exec_err, + datasource::MemTable, + error::{DataFusionError, Result}, +}; +use datafusion_ffi::catalog_provider::FFI_CatalogProvider; +use pyo3::types::PyCapsule; + +pub fn my_table() -> Arc { + use arrow::datatypes::{DataType, Field}; + use datafusion::common::record_batch; + + let schema = Arc::new(Schema::new(vec![ + Field::new("units", DataType::Int32, true), + Field::new("price", DataType::Float64, true), + ])); + + let partitions = vec![ + record_batch!( + ("units", Int32, vec![10, 20, 30]), + ("price", Float64, vec![1.0, 2.0, 5.0]) + ) + .unwrap(), + record_batch!( + ("units", Int32, vec![5, 7]), + ("price", Float64, vec![1.5, 2.5]) + ) + .unwrap(), + ]; + + Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap()) +} + +#[derive(Debug)] +pub struct FixedSchemaProvider { + inner: MemorySchemaProvider, +} + +impl Default for FixedSchemaProvider { + fn default() -> Self { + let inner = MemorySchemaProvider::new(); + + let table = my_table(); + + let _ = inner.register_table("my_table".to_string(), table).unwrap(); + + Self { inner } + } +} + +#[async_trait] +impl SchemaProvider for FixedSchemaProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + self.inner.table_names() + } + + async fn table(&self, name: &str) -> Result>, DataFusionError> { + self.inner.table(name).await + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> Result>> { + self.inner.register_table(name, table) + } + + fn deregister_table(&self, name: &str) -> Result>> { + self.inner.deregister_table(name) + } + + fn table_exist(&self, name: &str) -> bool { + self.inner.table_exist(name) + } +} + +/// This catalog provider is intended only for unit tests. It prepopulates with one +/// schema and only allows for schemas named after four types of fruit. +#[pyclass( + name = "MyCatalogProvider", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug)] +pub(crate) struct MyCatalogProvider { + inner: MemoryCatalogProvider, +} + +impl Default for MyCatalogProvider { + fn default() -> Self { + let inner = MemoryCatalogProvider::new(); + + let schema_name: &str = "my_schema"; + let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default())); + + Self { inner } + } +} + +impl CatalogProvider for MyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.inner.schema_names() + } + + fn schema(&self, name: &str) -> Option> { + self.inner.schema(name) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + self.inner.register_schema(name, schema) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result>> { + self.inner.deregister_schema(name, cascade) + } +} + +#[pymethods] +impl MyCatalogProvider { + #[new] + pub fn new() -> Self { + Self { + inner: Default::default(), + } + } + + pub fn __datafusion_catalog_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_catalog_provider".into(); + let catalog_provider = + FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); + + PyCapsule::new(py, catalog_provider, Some(name)) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index ae08c3b65..3a4cf2247 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::catalog_provider::MyCatalogProvider; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; use pyo3::prelude::*; +pub(crate) mod catalog_provider; pub(crate) mod table_function; pub(crate) mod table_provider; @@ -26,5 +28,6 @@ pub(crate) mod table_provider; fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 4f7700251..4d385460f 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -32,7 +32,7 @@ from datafusion.col import col, column -from . import functions, object_store, substrait, unparser +from . import catalog, functions, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. from ._internal import Config @@ -91,6 +91,7 @@ "TableFunction", "WindowFrame", "WindowUDF", + "catalog", "col", "column", "common", diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index 67ab3ead2..bebd38161 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -26,11 +26,23 @@ if TYPE_CHECKING: import pyarrow as pa +try: + from warnings import deprecated # Python 3.13+ +except ImportError: + from typing_extensions import deprecated # Python 3.12 + + +__all__ = [ + "Catalog", + "Schema", + "Table", +] + class Catalog: """DataFusion data catalog.""" - def __init__(self, catalog: df_internal.Catalog) -> None: + def __init__(self, catalog: df_internal.catalog.RawCatalog) -> None: """This constructor is not typically called by the end user.""" self.catalog = catalog @@ -38,39 +50,79 @@ def __repr__(self) -> str: """Print a string representation of the catalog.""" return self.catalog.__repr__() - def names(self) -> list[str]: - """Returns the list of databases in this catalog.""" - return self.catalog.names() + def names(self) -> set[str]: + """This is an alias for `schema_names`.""" + return self.schema_names() + + def schema_names(self) -> set[str]: + """Returns the list of schemas in this catalog.""" + return self.catalog.schema_names() + + def schema(self, name: str = "public") -> Schema: + """Returns the database with the given ``name`` from this catalog.""" + schema = self.catalog.schema(name) + + return ( + Schema(schema) + if isinstance(schema, df_internal.catalog.RawSchema) + else schema + ) - def database(self, name: str = "public") -> Database: + @deprecated("Use `schema` instead.") + def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" - return Database(self.catalog.database(name)) + return self.schema(name) + def register_schema(self, name, schema) -> Schema | None: + """Register a schema with this catalog.""" + return self.catalog.register_schema(name, schema) -class Database: - """DataFusion Database.""" + def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None: + """Deregister a schema from this catalog.""" + return self.catalog.deregister_schema(name, cascade) - def __init__(self, db: df_internal.Database) -> None: + +class Schema: + """DataFusion Schema.""" + + def __init__(self, schema: df_internal.catalog.RawSchema) -> None: """This constructor is not typically called by the end user.""" - self.db = db + self._raw_schema = schema def __repr__(self) -> str: - """Print a string representation of the database.""" - return self.db.__repr__() + """Print a string representation of the schema.""" + return self._raw_schema.__repr__() def names(self) -> set[str]: - """Returns the list of all tables in this database.""" - return self.db.names() + """This is an alias for `table_names`.""" + return self.table_names() + + def table_names(self) -> set[str]: + """Returns the list of all tables in this schema.""" + return self._raw_schema.table_names def table(self, name: str) -> Table: - """Return the table with the given ``name`` from this database.""" - return Table(self.db.table(name)) + """Return the table with the given ``name`` from this schema.""" + return Table(self._raw_schema.table(name)) + + def register_table(self, name, table) -> None: + """Register a table provider in this schema.""" + return self._raw_schema.register_table(name, table) + + def deregister_table(self, name: str) -> None: + """Deregister a table provider from this schema.""" + return self._raw_schema.deregister_table(name) + + +@deprecated("Use `Schema` instead.") +class Database(Schema): + """See `Schema`.""" class Table: """DataFusion table.""" - def __init__(self, table: df_internal.Table) -> None: + def __init__(self, table: df_internal.catalog.RawTable) -> None: """This constructor is not typically called by the end user.""" self.table = table diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 4ed465c99..4dc9f42f0 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -78,6 +78,15 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +class CatalogProviderExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html + """ + + def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105 + + class SessionConfig: """Session configuration options.""" @@ -746,6 +755,12 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def register_catalog_provider( + self, name: str, provider: CatalogProviderExportable + ) -> None: + """Register a catalog provider.""" + self.ctx.register_catalog_provider(name, provider) + def register_table_provider( self, name: str, provider: TableProviderExportable ) -> None: diff --git a/python/tests/test_catalog.py b/python/tests/test_catalog.py index 23b328458..21b0a3e0a 100644 --- a/python/tests/test_catalog.py +++ b/python/tests/test_catalog.py @@ -15,8 +15,11 @@ # specific language governing permissions and limitations # under the License. +import datafusion as dfn import pyarrow as pa +import pyarrow.dataset as ds import pytest +from datafusion import SessionContext, Table # Note we take in `database` as a variable even though we don't use @@ -27,7 +30,7 @@ def test_basic(ctx, database): ctx.catalog("non-existent") default = ctx.catalog() - assert default.names() == ["public"] + assert default.names() == {"public"} for db in [default.database("public"), default.database()]: assert db.names() == {"csv1", "csv", "csv2"} @@ -41,3 +44,100 @@ def test_basic(ctx, database): pa.field("float", pa.float64(), nullable=True), ] ) + + +class CustomTableProvider: + def __init__(self): + pass + + +def create_dataset() -> pa.dataset.Dataset: + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + return ds.dataset([batch]) + + +class CustomSchemaProvider: + def __init__(self): + self.tables = {"table1": create_dataset()} + + def table_names(self) -> set[str]: + return set(self.tables.keys()) + + def register_table(self, name: str, table: Table): + self.tables[name] = table + + def deregister_table(self, name, cascade: bool = True): + del self.tables[name] + + +class CustomCatalogProvider: + def __init__(self): + self.schemas = {"my_schema": CustomSchemaProvider()} + + def schema_names(self) -> set[str]: + return set(self.schemas.keys()) + + def schema(self, name: str): + return self.schemas[name] + + def register_schema(self, name: str, schema: dfn.catalog.Schema): + self.schemas[name] = schema + + def deregister_schema(self, name, cascade: bool): + del self.schemas[name] + + +def test_python_catalog_provider(ctx: SessionContext): + ctx.register_catalog_provider("my_catalog", CustomCatalogProvider()) + + # Check the default catalog provider + assert ctx.catalog("datafusion").names() == {"public"} + + my_catalog = ctx.catalog("my_catalog") + assert my_catalog.names() == {"my_schema"} + + my_catalog.register_schema("second_schema", CustomSchemaProvider()) + assert my_catalog.schema_names() == {"my_schema", "second_schema"} + + my_catalog.deregister_schema("my_schema") + assert my_catalog.schema_names() == {"second_schema"} + + +def test_python_schema_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.deregister_schema("public") + + catalog.register_schema("test_schema1", CustomSchemaProvider()) + assert catalog.names() == {"test_schema1"} + + catalog.register_schema("test_schema2", CustomSchemaProvider()) + catalog.deregister_schema("test_schema1") + assert catalog.names() == {"test_schema2"} + + +def test_python_table_provider(ctx: SessionContext): + catalog = ctx.catalog() + + catalog.register_schema("custom_schema", CustomSchemaProvider()) + schema = catalog.schema("custom_schema") + + assert schema.table_names() == {"table1"} + + schema.deregister_table("table1") + schema.register_table("table2", create_dataset()) + assert schema.table_names() == {"table2"} + + # Use the default schema instead of our custom schema + + schema = catalog.schema() + + schema.register_table("table3", create_dataset()) + assert schema.table_names() == {"table3"} + + schema.deregister_table("table3") + schema.register_table("table4", create_dataset()) + assert schema.table_names() == {"table4"} diff --git a/src/catalog.rs b/src/catalog.rs index 83f8d08cb..ba96ce471 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -15,44 +15,50 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; -use std::sync::Arc; - -use pyo3::exceptions::PyKeyError; -use pyo3::prelude::*; - -use crate::errors::{PyDataFusionError, PyDataFusionResult}; -use crate::utils::wait_for_future; +use crate::dataset::Dataset; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; +use crate::utils::{validate_pycapsule, wait_for_future}; +use async_trait::async_trait; +use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, catalog::{CatalogProvider, SchemaProvider}, datasource::{TableProvider, TableType}, }; +use datafusion_ffi::schema_provider::{FFI_SchemaProvider, ForeignSchemaProvider}; +use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; +use pyo3::exceptions::PyKeyError; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; +use pyo3::IntoPyObjectExt; +use std::any::Any; +use std::collections::HashSet; +use std::sync::Arc; -#[pyclass(name = "Catalog", module = "datafusion", subclass)] +#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)] pub struct PyCatalog { pub catalog: Arc, } -#[pyclass(name = "Database", module = "datafusion", subclass)] -pub struct PyDatabase { - pub database: Arc, +#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)] +pub struct PySchema { + pub schema: Arc, } -#[pyclass(name = "Table", module = "datafusion", subclass)] +#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)] pub struct PyTable { pub table: Arc, } -impl PyCatalog { - pub fn new(catalog: Arc) -> Self { +impl From> for PyCatalog { + fn from(catalog: Arc) -> Self { Self { catalog } } } -impl PyDatabase { - pub fn new(database: Arc) -> Self { - Self { database } +impl From> for PySchema { + fn from(schema: Arc) -> Self { + Self { schema } } } @@ -68,36 +74,93 @@ impl PyTable { #[pymethods] impl PyCatalog { - fn names(&self) -> Vec { - self.catalog.schema_names() + #[new] + fn new(catalog: PyObject) -> Self { + let catalog_provider = + Arc::new(RustWrappedPyCatalogProvider::new(catalog)) as Arc; + catalog_provider.into() + } + + fn schema_names(&self) -> HashSet { + self.catalog.schema_names().into_iter().collect() } #[pyo3(signature = (name="public"))] - fn database(&self, name: &str) -> PyResult { - match self.catalog.schema(name) { - Some(database) => Ok(PyDatabase::new(database)), - None => Err(PyKeyError::new_err(format!( - "Database with name {name} doesn't exist." - ))), - } + fn schema(&self, name: &str) -> PyResult { + let schema = self + .catalog + .schema(name) + .ok_or(PyKeyError::new_err(format!( + "Schema with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.schema_provider.clone_ref(py)), + None => PySchema::from(schema).into_py_any(py), + } + }) + } + + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { + let capsule = schema_provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPySchemaProvider::new(schema_provider.into()); + Arc::new(provider) as Arc + }; + + let _ = self + .catalog + .register_schema(name, provider) + .map_err(py_datafusion_err)?; + + Ok(()) + } + + fn deregister_schema(&self, name: &str, cascade: bool) -> PyResult<()> { + let _ = self + .catalog + .deregister_schema(name, cascade) + .map_err(py_datafusion_err)?; + + Ok(()) } fn __repr__(&self) -> PyResult { - Ok(format!( - "Catalog(schema_names=[{}])", - self.names().join(";") - )) + let mut names: Vec = self.schema_names().into_iter().collect(); + names.sort(); + Ok(format!("Catalog(schema_names=[{}])", names.join(", "))) } } #[pymethods] -impl PyDatabase { - fn names(&self) -> HashSet { - self.database.table_names().into_iter().collect() +impl PySchema { + #[new] + fn new(schema_provider: PyObject) -> Self { + let schema_provider = + Arc::new(RustWrappedPySchemaProvider::new(schema_provider)) as Arc; + schema_provider.into() + } + + #[getter] + fn table_names(&self) -> HashSet { + self.schema.table_names().into_iter().collect() } fn table(&self, name: &str, py: Python) -> PyDataFusionResult { - if let Some(table) = wait_for_future(py, self.database.table(name))?? { + if let Some(table) = wait_for_future(py, self.schema.table(name))?? { Ok(PyTable::new(table)) } else { Err(PyDataFusionError::Common(format!( @@ -107,14 +170,44 @@ impl PyDatabase { } fn __repr__(&self) -> PyResult { - Ok(format!( - "Database(table_names=[{}])", - Vec::from_iter(self.names()).join(";") - )) + let mut names: Vec = self.table_names().into_iter().collect(); + names.sort(); + Ok(format!("Schema(table_names=[{}])", names.join(";"))) + } + + fn register_table(&self, name: &str, table_provider: Bound<'_, PyAny>) -> PyResult<()> { + let provider = if table_provider.hasattr("__datafusion_table_provider__")? { + let capsule = table_provider + .getattr("__datafusion_table_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let py = table_provider.py(); + let provider = Dataset::new(&table_provider, py)?; + Arc::new(provider) as Arc + }; + + let _ = self + .schema + .register_table(name.to_string(), provider) + .map_err(py_datafusion_err)?; + + Ok(()) } - // register_table - // deregister_table + fn deregister_table(&self, name: &str) -> PyResult<()> { + let _ = self + .schema + .deregister_table(name) + .map_err(py_datafusion_err)?; + + Ok(()) + } } #[pymethods] @@ -145,3 +238,265 @@ impl PyTable { // fn has_exact_statistics // fn supports_filter_pushdown } + +#[derive(Debug)] +pub(crate) struct RustWrappedPySchemaProvider { + schema_provider: PyObject, + owner_name: Option, +} + +impl RustWrappedPySchemaProvider { + pub fn new(schema_provider: PyObject) -> Self { + let owner_name = Python::with_gil(|py| { + schema_provider + .bind(py) + .getattr("owner_name") + .ok() + .map(|name| name.to_string()) + }); + + Self { + schema_provider, + owner_name, + } + } + + fn table_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let py_table_method = provider.getattr("table")?; + + let py_table = py_table_method.call((name,), None)?; + if py_table.is_none() { + return Ok(None); + } + + if py_table.hasattr("__datafusion_table_provider__")? { + let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_table_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignTableProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?; + + Ok(Some(Arc::new(ds) as Arc)) + } + }) + } +} + +#[async_trait] +impl SchemaProvider for RustWrappedPySchemaProvider { + fn owner_name(&self) -> Option<&str> { + self.owner_name.as_deref() + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn table_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + + provider + .getattr("table_names") + .and_then(|names| names.extract::>()) + .unwrap_or_else(|err| { + log::error!("Unable to get table_names: {err}"); + Vec::default() + }) + }) + } + + async fn table( + &self, + name: &str, + ) -> datafusion::common::Result>, DataFusionError> { + self.table_inner(name).map_err(to_datafusion_err) + } + + fn register_table( + &self, + name: String, + table: Arc, + ) -> datafusion::common::Result>> { + let py_table = PyTable::new(table); + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let _ = provider + .call_method1("register_table", (name, py_table)) + .map_err(to_datafusion_err)?; + // Since the definition of `register_table` says that an error + // will be returned if the table already exists, there is no + // case where we want to return a table provider as output. + Ok(None) + }) + } + + fn deregister_table( + &self, + name: &str, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + let table = provider + .call_method1("deregister_table", (name,)) + .map_err(to_datafusion_err)?; + if table.is_none() { + return Ok(None); + } + + // If we can turn this table provider into a `Dataset`, return it. + // Otherwise, return None. + let dataset = match Dataset::new(&table, py) { + Ok(dataset) => Some(Arc::new(dataset) as Arc), + Err(_) => None, + }; + + Ok(dataset) + }) + } + + fn table_exist(&self, name: &str) -> bool { + Python::with_gil(|py| { + let provider = self.schema_provider.bind(py); + provider + .call_method1("table_exist", (name,)) + .and_then(|pyobj| pyobj.extract()) + .unwrap_or(false) + }) + } +} + +#[derive(Debug)] +pub(crate) struct RustWrappedPyCatalogProvider { + pub(crate) catalog_provider: PyObject, +} + +impl RustWrappedPyCatalogProvider { + pub fn new(catalog_provider: PyObject) -> Self { + Self { catalog_provider } + } + + fn schema_inner(&self, name: &str) -> PyResult>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + + let py_schema = provider.call_method1("schema", (name,))?; + if py_schema.is_none() { + return Ok(None); + } + + if py_schema.hasattr("__datafusion_schema_provider__")? { + let capsule = provider + .getattr("__datafusion_schema_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_schema_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignSchemaProvider = provider.into(); + + Ok(Some(Arc::new(provider) as Arc)) + } else { + let py_schema = RustWrappedPySchemaProvider::new(py_schema.into()); + + Ok(Some(Arc::new(py_schema) as Arc)) + } + }) + } +} + +#[async_trait] +impl CatalogProvider for RustWrappedPyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + provider + .getattr("schema_names") + .and_then(|names| names.extract::>()) + .unwrap_or_else(|err| { + log::error!("Unable to get schema_names: {err}"); + Vec::default() + }) + }) + } + + fn schema(&self, name: &str) -> Option> { + self.schema_inner(name).unwrap_or_else(|err| { + log::error!("CatalogProvider schema returned error: {err}"); + None + }) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> datafusion::common::Result>> { + // JRIGHT HERE + // let py_schema: PySchema = schema.into(); + Python::with_gil(|py| { + let py_schema = match schema + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => wrapped_schema.schema_provider.as_any(), + None => &PySchema::from(schema) + .into_py_any(py) + .map_err(to_datafusion_err)?, + }; + + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("register_schema", (name, py_schema)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> datafusion::common::Result>> { + Python::with_gil(|py| { + let provider = self.catalog_provider.bind(py); + let schema = provider + .call_method1("deregister_schema", (name, cascade)) + .map_err(to_datafusion_err)?; + if schema.is_none() { + return Ok(None); + } + + let schema = Arc::new(RustWrappedPySchemaProvider::new(schema.into())) + as Arc; + + Ok(Some(schema)) + }) + } +} + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/src/context.rs b/src/context.rs index 55c92a8fa..c92c84acc 100644 --- a/src/context.rs +++ b/src/context.rs @@ -31,7 +31,7 @@ use uuid::Uuid; use pyo3::exceptions::{PyKeyError, PyValueError}; use pyo3::prelude::*; -use crate::catalog::{PyCatalog, PyTable}; +use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; @@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::catalog::CatalogProvider; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -69,8 +70,10 @@ use datafusion::physical_plan::SendableRecordBatchStream; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; +use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider}; use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; +use pyo3::IntoPyObjectExt; use tokio::task::JoinHandle; /// Configuration options for a SessionContext @@ -609,6 +612,31 @@ impl PySessionContext { Ok(()) } + pub fn register_catalog_provider( + &mut self, + name: &str, + provider: Bound<'_, PyAny>, + ) -> PyDataFusionResult<()> { + let provider = if provider.hasattr("__datafusion_catalog_provider__")? { + let capsule = provider + .getattr("__datafusion_catalog_provider__")? + .call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_catalog_provider")?; + + let provider = unsafe { capsule.reference::() }; + let provider: ForeignCatalogProvider = provider.into(); + Arc::new(provider) as Arc + } else { + let provider = RustWrappedPyCatalogProvider::new(provider.into()); + Arc::new(provider) as Arc + }; + + let _ = self.ctx.register_catalog(name, provider); + + Ok(()) + } + /// Construct datafusion dataframe from Arrow Table pub fn register_table_provider( &mut self, @@ -826,14 +854,20 @@ impl PySessionContext { } #[pyo3(signature = (name="datafusion"))] - pub fn catalog(&self, name: &str) -> PyResult { - match self.ctx.catalog(name) { - Some(catalog) => Ok(PyCatalog::new(catalog)), - None => Err(PyKeyError::new_err(format!( - "Catalog with name {} doesn't exist.", - &name, - ))), - } + pub fn catalog(&self, name: &str) -> PyResult { + let catalog = self.ctx.catalog(name).ok_or(PyKeyError::new_err(format!( + "Catalog with name {name} doesn't exist." + )))?; + + Python::with_gil(|py| { + match catalog + .as_any() + .downcast_ref::() + { + Some(wrapped_schema) => Ok(wrapped_schema.catalog_provider.clone_ref(py)), + None => PyCatalog::from(catalog).into_py_any(py), + } + }) } pub fn tables(&self) -> HashSet { diff --git a/src/functions.rs b/src/functions.rs index b2bafcb65..b40500b8b 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -937,7 +937,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(left))?; m.add_wrapped(wrap_pyfunction!(length))?; m.add_wrapped(wrap_pyfunction!(ln))?; - m.add_wrapped(wrap_pyfunction!(log))?; + m.add_wrapped(wrap_pyfunction!(self::log))?; m.add_wrapped(wrap_pyfunction!(log10))?; m.add_wrapped(wrap_pyfunction!(log2))?; m.add_wrapped(wrap_pyfunction!(lower))?; diff --git a/src/lib.rs b/src/lib.rs index 7dced1fbd..728490938 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,10 +77,10 @@ pub(crate) struct TokioRuntime(tokio::runtime::Runtime); /// datafusion directory. #[pymodule] fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { + // Initialize logging + pyo3_log::init(); + // Register the python classes - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -96,6 +96,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + let catalog = PyModule::new(py, "catalog")?; + catalog::init_module(&catalog)?; + m.add_submodule(&catalog)?; + // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/ let common = PyModule::new(py, "common")?; common::init_module(&common)?; From 987b6c7a44b7e389f00f539652959e5ed240f62c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 15 Feb 2025 13:37:07 -0500 Subject: [PATCH 2/3] Intermediate work adding ffi scalar udf Add scalar UDF and example Add aggregate udf via ffi Initial commit for window ffi integration Remove patch --- docs/source/contributor-guide/ffi.rst | 2 +- examples/datafusion-ffi-example/Cargo.lock | 217 ++++++++++-------- examples/datafusion-ffi-example/Cargo.toml | 8 +- .../python/tests/_test_aggregate_udf.py | 77 +++++++ .../python/tests/_test_scalar_udf.py | 70 ++++++ .../python/tests/_test_window_udf.py | 89 +++++++ .../src/aggregate_udf.rs | 81 +++++++ examples/datafusion-ffi-example/src/lib.rs | 9 + .../datafusion-ffi-example/src/scalar_udf.rs | 91 ++++++++ .../datafusion-ffi-example/src/window_udf.rs | 81 +++++++ python/datafusion/user_defined.py | 107 ++++++++- src/functions.rs | 2 +- src/udaf.rs | 31 ++- src/udf.rs | 25 +- src/udwf.rs | 27 ++- 15 files changed, 805 insertions(+), 112 deletions(-) create mode 100644 examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py create mode 100644 examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py create mode 100644 examples/datafusion-ffi-example/python/tests/_test_window_udf.py create mode 100644 examples/datafusion-ffi-example/src/aggregate_udf.rs create mode 100644 examples/datafusion-ffi-example/src/scalar_udf.rs create mode 100644 examples/datafusion-ffi-example/src/window_udf.rs diff --git a/docs/source/contributor-guide/ffi.rst b/docs/source/contributor-guide/ffi.rst index c1f9806b3..a40af1234 100644 --- a/docs/source/contributor-guide/ffi.rst +++ b/docs/source/contributor-guide/ffi.rst @@ -176,7 +176,7 @@ By convention the ``datafusion-python`` library expects a Python object that has ``TableProvider`` PyCapsule to have this capsule accessible by calling a function named ``__datafusion_table_provider__``. You can see a complete working example of how to share a ``TableProvider`` from one python library to DataFusion Python in the -`repository examples folder `_. +`repository examples folder `_. This section has been written using ``TableProvider`` as an example. It is the first extension that has been written using this approach and the most thoroughly implemented. diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock index e5a1ca8d1..1b4ca6bee 100644 --- a/examples/datafusion-ffi-example/Cargo.lock +++ b/examples/datafusion-ffi-example/Cargo.lock @@ -323,6 +323,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73a47aa0c771b5381de2b7f16998d351a6f4eb839f1e13d48353e17e873d969b" dependencies = [ "bitflags", + "serde", + "serde_json", ] [[package]] @@ -748,9 +750,9 @@ dependencies = [ [[package]] name = "datafusion" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffe060b978f74ab446be722adb8a274e052e005bf6dfd171caadc3abaad10080" +checksum = "cc6cb8c2c81eada072059983657d6c9caf3fddefc43b4a65551d243253254a96" dependencies = [ "arrow", "arrow-ipc", @@ -775,7 +777,6 @@ dependencies = [ "datafusion-functions-nested", "datafusion-functions-table", "datafusion-functions-window", - "datafusion-macros", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-expr-common", @@ -790,7 +791,7 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand", + "rand 0.9.1", "regex", "sqlparser", "tempfile", @@ -803,9 +804,9 @@ dependencies = [ [[package]] name = "datafusion-catalog" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61fe34f401bd03724a1f96d12108144f8cd495a3cdda2bf5e091822fb80b7e66" +checksum = "b7be8d1b627843af62e447396db08fe1372d882c0eb8d0ea655fd1fbc33120ee" dependencies = [ "arrow", "async-trait", @@ -829,9 +830,9 @@ dependencies = [ [[package]] name = "datafusion-catalog-listing" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4411b8e3bce5e0fc7521e44f201def2e2d5d1b5f176fb56e8cdc9942c890f00" +checksum = "38ab16c5ae43f65ee525fc493ceffbc41f40dee38b01f643dfcfc12959e92038" dependencies = [ "arrow", "async-trait", @@ -852,9 +853,9 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0734015d81c8375eb5d4869b7f7ecccc2ee8d6cb81948ef737cd0e7b743bd69c" +checksum = "d3d56b2ac9f476b93ca82e4ef5fb00769c8a3f248d12b4965af7e27635fa7e12" dependencies = [ "ahash", "arrow", @@ -876,9 +877,9 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5167bb1d2ccbb87c6bc36c295274d7a0519b14afcfdaf401d53cbcaa4ef4968b" +checksum = "16015071202d6133bc84d72756176467e3e46029f3ce9ad2cb788f9b1ff139b2" dependencies = [ "futures", "log", @@ -887,9 +888,9 @@ dependencies = [ [[package]] name = "datafusion-datasource" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e602dcdf2f50c2abf297cc2203c73531e6f48b29516af7695d338cf2a778b1" +checksum = "b77523c95c89d2a7eb99df14ed31390e04ab29b43ff793e562bdc1716b07e17b" dependencies = [ "arrow", "async-compression", @@ -912,7 +913,7 @@ dependencies = [ "log", "object_store", "parquet", - "rand", + "rand 0.9.1", "tempfile", "tokio", "tokio-util", @@ -923,9 +924,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-csv" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bb2253952dc32296ed5b84077cb2e0257fea4be6373e1c376426e17ead4ef6" +checksum = "40d25c5e2c0ebe8434beeea997b8e88d55b3ccc0d19344293f2373f65bc524fc" dependencies = [ "arrow", "async-trait", @@ -948,9 +949,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-json" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8c7f47a5d2fe03bfa521ec9bafdb8a5c82de8377f60967c3663f00c8790352" +checksum = "3dc6959e1155741ab35369e1dc7673ba30fc45ed568fad34c01b7cb1daeb4d4c" dependencies = [ "arrow", "async-trait", @@ -973,9 +974,9 @@ dependencies = [ [[package]] name = "datafusion-datasource-parquet" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27d15868ea39ed2dc266728b554f6304acd473de2142281ecfa1294bb7415923" +checksum = "b7a6afdfe358d70f4237f60eaef26ae5a1ce7cb2c469d02d5fc6c7fd5d84e58b" dependencies = [ "arrow", "async-trait", @@ -998,21 +999,21 @@ dependencies = [ "object_store", "parking_lot", "parquet", - "rand", + "rand 0.9.1", "tokio", ] [[package]] name = "datafusion-doc" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a91f8c2c5788ef32f48ff56c68e5b545527b744822a284373ac79bba1ba47292" +checksum = "9bcd8a3e3e3d02ea642541be23d44376b5d5c37c2938cce39b3873cdf7186eea" [[package]] name = "datafusion-execution" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06f004d100f49a3658c9da6fb0c3a9b760062d96cd4ad82ccc3b7b69a9fb2f84" +checksum = "670da1d45d045eee4c2319b8c7ea57b26cf48ab77b630aaa50b779e406da476a" dependencies = [ "arrow", "dashmap", @@ -1022,16 +1023,16 @@ dependencies = [ "log", "object_store", "parking_lot", - "rand", + "rand 0.9.1", "tempfile", "url", ] [[package]] name = "datafusion-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a4e4ce3802609be38eeb607ee72f6fe86c3091460de9dbfae9e18db423b3964" +checksum = "b3a577f64bdb7e2cc4043cd97f8901d8c504711fde2dbcb0887645b00d7c660b" dependencies = [ "arrow", "chrono", @@ -1050,9 +1051,9 @@ dependencies = [ [[package]] name = "datafusion-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "422ac9cf3b22bbbae8cdf8ceb33039107fde1b5492693168f13bd566b1bcc839" +checksum = "51b7916806ace3e9f41884f230f7f38ebf0e955dfbd88266da1826f29a0b9a6a" dependencies = [ "arrow", "datafusion-common", @@ -1063,9 +1064,9 @@ dependencies = [ [[package]] name = "datafusion-ffi" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cf3fe9ab492c56daeb7beed526690d33622d388b8870472e0b7b7f55490338c" +checksum = "980cca31de37f5dadf7ea18e4ffc2b6833611f45bed5ef9de0831d2abb50f1ef" dependencies = [ "abi_stable", "arrow", @@ -1073,7 +1074,9 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", + "datafusion-proto-common", "futures", "log", "prost", @@ -1081,11 +1084,25 @@ dependencies = [ "tokio", ] +[[package]] +name = "datafusion-ffi-example" +version = "0.2.0" +dependencies = [ + "arrow", + "arrow-array", + "arrow-schema", + "async-trait", + "datafusion", + "datafusion-ffi", + "pyo3", + "pyo3-build-config", +] + [[package]] name = "datafusion-functions" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ddf0a0a2db5d2918349c978d42d80926c6aa2459cd8a3c533a84ec4bb63479e" +checksum = "7fb31c9dc73d3e0c365063f91139dc273308f8a8e124adda9898db8085d68357" dependencies = [ "arrow", "arrow-buffer", @@ -1103,7 +1120,7 @@ dependencies = [ "itertools", "log", "md-5", - "rand", + "rand 0.9.1", "regex", "sha2", "unicode-segmentation", @@ -1112,9 +1129,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "408a05dafdc70d05a38a29005b8b15e21b0238734dab1e98483fcb58038c5aba" +checksum = "ebb72c6940697eaaba9bd1f746a697a07819de952b817e3fb841fb75331ad5d4" dependencies = [ "ahash", "arrow", @@ -1133,9 +1150,9 @@ dependencies = [ [[package]] name = "datafusion-functions-aggregate-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "756d21da2dd6c9bef97af1504970ff56cbf35d03fbd4ffd62827f02f4d2279d4" +checksum = "d7fdc54656659e5ecd49bf341061f4156ab230052611f4f3609612a0da259696" dependencies = [ "ahash", "arrow", @@ -1146,9 +1163,9 @@ dependencies = [ [[package]] name = "datafusion-functions-nested" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d8d50f6334b378930d992d801a10ac5b3e93b846b39e4a05085742572844537" +checksum = "fad94598e3374938ca43bca6b675febe557e7a14eb627d617db427d70d65118b" dependencies = [ "arrow", "arrow-ord", @@ -1167,9 +1184,9 @@ dependencies = [ [[package]] name = "datafusion-functions-table" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc9a97220736c8fff1446e936be90d57216c06f28969f9ffd3b72ac93c958c8a" +checksum = "de2fc6c2946da5cab8364fb28b5cac3115f0f3a87960b235ed031c3f7e2e639b" dependencies = [ "arrow", "async-trait", @@ -1183,10 +1200,11 @@ dependencies = [ [[package]] name = "datafusion-functions-window" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cefc2d77646e1aadd1d6a9c40088937aedec04e68c5f0465939912e1291f8193" +checksum = "3e5746548a8544870a119f556543adcd88fe0ba6b93723fe78ad0439e0fbb8b4" dependencies = [ + "arrow", "datafusion-common", "datafusion-doc", "datafusion-expr", @@ -1200,9 +1218,9 @@ dependencies = [ [[package]] name = "datafusion-functions-window-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd4aff082c42fa6da99ce0698c85addd5252928c908eb087ca3cfa64ff16b313" +checksum = "dcbe9404382cda257c434f22e13577bee7047031dfdb6216dd5e841b9465e6fe" dependencies = [ "datafusion-common", "datafusion-physical-expr-common", @@ -1210,9 +1228,9 @@ dependencies = [ [[package]] name = "datafusion-macros" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df6f88d7ee27daf8b108ba910f9015176b36fbc72902b1ca5c2a5f1d1717e1a1" +checksum = "8dce50e3b637dab0d25d04d2fe79dfdca2b257eabd76790bffd22c7f90d700c8" dependencies = [ "datafusion-expr", "quote", @@ -1221,9 +1239,9 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "084d9f979c4b155346d3c34b18f4256e6904ded508e9554d90fed416415c3515" +checksum = "03cfaacf06445dc3bbc1e901242d2a44f2cae99a744f49f3fefddcee46240058" dependencies = [ "arrow", "chrono", @@ -1240,9 +1258,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64c536062b0076f4e30084065d805f389f9fe38af0ca75bcbac86bc5e9fbab65" +checksum = "1908034a89d7b2630898e06863583ae4c00a0dd310c1589ca284195ee3f7f8a6" dependencies = [ "ahash", "arrow", @@ -1262,9 +1280,9 @@ dependencies = [ [[package]] name = "datafusion-physical-expr-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8a92b53b3193fac1916a1c5b8e3f4347c526f6822e56b71faa5fb372327a863" +checksum = "47b7a12dd59ea07614b67dbb01d85254fbd93df45bcffa63495e11d3bdf847df" dependencies = [ "ahash", "arrow", @@ -1276,9 +1294,9 @@ dependencies = [ [[package]] name = "datafusion-physical-optimizer" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa0a5ac94c7cf3da97bedabd69d6bbca12aef84b9b37e6e9e8c25286511b5e2" +checksum = "4371cc4ad33978cc2a8be93bd54a232d3f2857b50401a14631c0705f3f910aae" dependencies = [ "arrow", "datafusion-common", @@ -1295,9 +1313,9 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "690c615db468c2e5fe5085b232d8b1c088299a6c63d87fd960a354a71f7acb55" +checksum = "dc47bc33025757a5c11f2cd094c5b6b5ed87f46fa33c023e6fdfa25fcbfade23" dependencies = [ "ahash", "arrow", @@ -1325,9 +1343,9 @@ dependencies = [ [[package]] name = "datafusion-proto" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a1afb2bdb05de7ff65be6883ebfd4ec027bd9f1f21c46aa3afd01927160a83" +checksum = "d8f5d9acd7d96e3bf2a7bb04818373cab6e51de0356e3694b94905fee7b4e8b6" dependencies = [ "arrow", "chrono", @@ -1341,9 +1359,9 @@ dependencies = [ [[package]] name = "datafusion-proto-common" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35b7a5876ebd6b564fb9a1fd2c3a2a9686b787071a256b47e4708f0916f9e46f" +checksum = "09ecb5ec152c4353b60f7a5635489834391f7a291d2b39a4820cd469e318b78e" dependencies = [ "arrow", "datafusion-common", @@ -1352,9 +1370,9 @@ dependencies = [ [[package]] name = "datafusion-session" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad229a134c7406c057ece00c8743c0c34b97f4e72f78b475fe17b66c5e14fa4f" +checksum = "d7485da32283985d6b45bd7d13a65169dcbe8c869e25d01b2cfbc425254b4b49" dependencies = [ "arrow", "async-trait", @@ -1376,9 +1394,9 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "47.0.0" +version = "48.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64f6ab28b72b664c21a27b22a2ff815fd390ed224c26e89a93b5a8154a4e8607" +checksum = "a466b15632befddfeac68c125f0260f569ff315c6831538cbb40db754134e0df" dependencies = [ "arrow", "bigdecimal", @@ -1441,20 +1459,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" -[[package]] -name = "ffi-table-provider" -version = "0.1.0" -dependencies = [ - "arrow", - "arrow-array", - "arrow-schema", - "async-trait", - "datafusion", - "datafusion-ffi", - "pyo3", - "pyo3-build-config", -] - [[package]] name = "fixedbitset" version = "0.5.7" @@ -1488,6 +1492,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1666,6 +1676,11 @@ name = "hashbrown" version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] [[package]] name = "heck" @@ -2271,12 +2286,14 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "petgraph" -version = "0.7.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +checksum = "54acf3a685220b533e437e264e4d932cfbdc4cc7ec0cd232ed73c08d03b8a7ca" dependencies = [ "fixedbitset", + "hashbrown 0.15.3", "indexmap", + "serde", ] [[package]] @@ -2305,7 +2322,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] @@ -2484,19 +2501,27 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +dependencies = [ "rand_chacha", - "rand_core", + "rand_core 0.9.3", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", ] [[package]] @@ -2504,8 +2529,14 @@ name = "rand_core" version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.2.16", + "getrandom 0.3.3", ] [[package]] @@ -3032,9 +3063,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", "js-sys", diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml index 319163554..b26ab48e3 100644 --- a/examples/datafusion-ffi-example/Cargo.toml +++ b/examples/datafusion-ffi-example/Cargo.toml @@ -16,13 +16,13 @@ # under the License. [package] -name = "ffi-table-provider" -version = "0.1.0" +name = "datafusion-ffi-example" +version = "0.2.0" edition = "2021" [dependencies] -datafusion = { version = "47.0.0" } -datafusion-ffi = { version = "47.0.0" } +datafusion = { version = "48.0.0" } +datafusion-ffi = { version = "48.0.0" } pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"] } arrow = { version = "55.0.0" } arrow-array = { version = "55.0.0" } diff --git a/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py new file mode 100644 index 000000000..7ea6b295c --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_aggregate_udf.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udaf +from datafusion_ffi_example import MySumUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + # Pick numbers here so we get the same value in both groups + # since we cannot be certain of the output order of batches + batch = pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3, None], type=pa.int64()), + pa.array([1, 1, 2, 2], type=pa.int64()), + ], + names=["a", "b"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_aggregate_register(): + ctx = setup_context_with_table() + my_udaf = udaf(MySumUDF()) + ctx.register_udaf(my_udaf) + + result = ctx.sql("select my_custom_sum(a) from test_table group by b").collect() + + assert len(result) == 2 + assert result[0].num_columns == 1 + + result = [r.column(0) for r in result] + expected = [ + pa.array([3], type=pa.int64()), + pa.array([3], type=pa.int64()), + ] + + assert result == expected + + +def test_ffi_aggregate_call_directly(): + ctx = setup_context_with_table() + my_udaf = udaf(MySumUDF()) + + result = ( + ctx.table("test_table").aggregate([col("b")], [my_udaf(col("a"))]).collect() + ) + + assert len(result) == 2 + assert result[0].num_columns == 2 + + result = [r.column(1) for r in result] + expected = [ + pa.array([3], type=pa.int64()), + pa.array([3], type=pa.int64()), + ] + + assert result == expected diff --git a/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py new file mode 100644 index 000000000..0c949c34a --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_scalar_udf.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udf +from datafusion_ffi_example import IsNullUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, None])], + names=["a"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_scalar_register(): + ctx = setup_context_with_table() + my_udf = udf(IsNullUDF()) + ctx.register_udf(my_udf) + + result = ctx.sql("select my_custom_is_null(a) from test_table").collect() + + assert len(result) == 1 + assert result[0].num_columns == 1 + print(result) + + result = [r.column(0) for r in result] + expected = [ + pa.array([False, False, False, True], type=pa.bool_()), + ] + + assert result == expected + + +def test_ffi_scalar_call_directly(): + ctx = setup_context_with_table() + my_udf = udf(IsNullUDF()) + + result = ctx.table("test_table").select(my_udf(col("a"))).collect() + + assert len(result) == 1 + assert result[0].num_columns == 1 + print(result) + + result = [r.column(0) for r in result] + expected = [ + pa.array([False, False, False, True], type=pa.bool_()), + ] + + assert result == expected diff --git a/examples/datafusion-ffi-example/python/tests/_test_window_udf.py b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py new file mode 100644 index 000000000..7d96994b9 --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_window_udf.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, col, udwf +from datafusion_ffi_example import MyRankUDF + + +def setup_context_with_table(): + ctx = SessionContext() + + # Pick numbers here so we get the same value in both groups + # since we cannot be certain of the output order of batches + batch = pa.RecordBatch.from_arrays( + [ + pa.array([40, 10, 30, 20], type=pa.int64()), + ], + names=["a"], + ) + ctx.register_record_batches("test_table", [[batch]]) + return ctx + + +def test_ffi_window_register(): + ctx = setup_context_with_table() + my_udwf = udwf(MyRankUDF()) + ctx.register_udwf(my_udwf) + + result = ctx.sql( + "select a, my_custom_rank() over (order by a) from test_table" + ).collect() + assert len(result) == 1 + assert result[0].num_columns == 2 + + results = [ + (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4) + ] + results.sort() + + expected = [ + (10, 1), + (20, 2), + (30, 3), + (40, 4), + ] + assert results == expected + + +def test_ffi_window_call_directly(): + ctx = setup_context_with_table() + my_udwf = udwf(MyRankUDF()) + + result = ( + ctx.table("test_table") + .select(col("a"), my_udwf().order_by(col("a")).build()) + .collect() + ) + + assert len(result) == 1 + assert result[0].num_columns == 2 + + results = [ + (result[0][0][idx].as_py(), result[0][1][idx].as_py()) for idx in range(4) + ] + results.sort() + + expected = [ + (10, 1), + (20, 2), + (30, 3), + (40, 4), + ] + assert results == expected diff --git a/examples/datafusion-ffi-example/src/aggregate_udf.rs b/examples/datafusion-ffi-example/src/aggregate_udf.rs new file mode 100644 index 000000000..9481fe9c6 --- /dev/null +++ b/examples/datafusion-ffi-example/src/aggregate_udf.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::DataType; +use datafusion::error::Result as DataFusionResult; +use datafusion::functions_aggregate::sum::Sum; +use datafusion::logical_expr::function::AccumulatorArgs; +use datafusion::logical_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; +use datafusion_ffi::udaf::FFI_AggregateUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "MySumUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct MySumUDF { + inner: Arc, +} + +#[pymethods] +impl MySumUDF { + #[new] + fn new() -> Self { + Self { + inner: Arc::new(Sum::new()), + } + } + + fn __datafusion_aggregate_udf__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_aggregate_udf".into(); + + let func = Arc::new(AggregateUDF::from(self.clone())); + let provider = FFI_AggregateUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl AggregateUDFImpl for MySumUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_sum" + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + self.inner.return_type(arg_types) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> DataFusionResult> { + self.inner.accumulator(acc_args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult> { + self.inner.coerce_types(arg_types) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index 3a4cf2247..f5f96cd49 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,19 +15,28 @@ // specific language governing permissions and limitations // under the License. +use crate::aggregate_udf::MySumUDF; use crate::catalog_provider::MyCatalogProvider; +use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; +use crate::window_udf::MyRankUDF; use pyo3::prelude::*; +pub(crate) mod aggregate_udf; pub(crate) mod catalog_provider; +pub(crate) mod scalar_udf; pub(crate) mod table_function; pub(crate) mod table_provider; +pub(crate) mod window_udf; #[pymodule] fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/examples/datafusion-ffi-example/src/scalar_udf.rs b/examples/datafusion-ffi-example/src/scalar_udf.rs new file mode 100644 index 000000000..727666638 --- /dev/null +++ b/examples/datafusion-ffi-example/src/scalar_udf.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{Array, BooleanArray}; +use arrow_schema::DataType; +use datafusion::common::ScalarValue; +use datafusion::error::Result as DataFusionResult; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, + Volatility, +}; +use datafusion_ffi::udf::FFI_ScalarUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "IsNullUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct IsNullUDF { + signature: Signature, +} + +#[pymethods] +impl IsNullUDF { + #[new] + fn new() -> Self { + Self { + signature: Signature::new(TypeSignature::Any(1), Volatility::Immutable), + } + } + + fn __datafusion_scalar_udf__<'py>(&self, py: Python<'py>) -> PyResult> { + let name = cr"datafusion_scalar_udf".into(); + + let func = Arc::new(ScalarUDF::from(self.clone())); + let provider = FFI_ScalarUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl ScalarUDFImpl for IsNullUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_is_null" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult { + Ok(DataType::Boolean) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult { + let input = &args.args[0]; + + Ok(match input { + ColumnarValue::Array(arr) => match arr.is_nullable() { + true => { + let nulls = arr.nulls().unwrap(); + let nulls = BooleanArray::from_iter(nulls.iter().map(|x| Some(!x))); + ColumnarValue::Array(Arc::new(nulls)) + } + false => ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))), + }, + ColumnarValue::Scalar(sv) => { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(sv == &ScalarValue::Null))) + } + }) + } +} diff --git a/examples/datafusion-ffi-example/src/window_udf.rs b/examples/datafusion-ffi-example/src/window_udf.rs new file mode 100644 index 000000000..e0d397956 --- /dev/null +++ b/examples/datafusion-ffi-example/src/window_udf.rs @@ -0,0 +1,81 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::{DataType, FieldRef}; +use datafusion::error::Result as DataFusionResult; +use datafusion::functions_window::rank::rank_udwf; +use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs}; +use datafusion::logical_expr::{PartitionEvaluator, Signature, WindowUDF, WindowUDFImpl}; +use datafusion_ffi::udwf::FFI_WindowUDF; +use pyo3::types::PyCapsule; +use pyo3::{pyclass, pymethods, Bound, PyResult, Python}; +use std::any::Any; +use std::sync::Arc; + +#[pyclass(name = "MyRankUDF", module = "datafusion_ffi_example", subclass)] +#[derive(Debug, Clone)] +pub(crate) struct MyRankUDF { + inner: Arc, +} + +#[pymethods] +impl MyRankUDF { + #[new] + fn new() -> Self { + Self { inner: rank_udwf() } + } + + fn __datafusion_window_udf__<'py>(&self, py: Python<'py>) -> PyResult> { + let name = cr"datafusion_window_udf".into(); + + let func = Arc::new(WindowUDF::from(self.clone())); + let provider = FFI_WindowUDF::from(func); + + PyCapsule::new(py, provider, Some(name)) + } +} + +impl WindowUDFImpl for MyRankUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "my_custom_rank" + } + + fn signature(&self) -> &Signature { + self.inner.signature() + } + + fn partition_evaluator( + &self, + partition_evaluator_args: PartitionEvaluatorArgs, + ) -> DataFusionResult> { + self.inner + .inner() + .partition_evaluator(partition_evaluator_args) + } + + fn field(&self, field_args: WindowUDFFieldArgs) -> DataFusionResult { + self.inner.inner().field(field_args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult> { + self.inner.coerce_types(arg_types) + } +} diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index dd634c7fb..bd686acbb 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -22,7 +22,7 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload import pyarrow as pa @@ -77,6 +77,12 @@ def __str__(self) -> str: return self.name.lower() +class ScalarUDFExportable(Protocol): + """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.""" + + def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105 + + class ScalarUDF: """Class for performing scalar user-defined functions (UDF). @@ -96,6 +102,9 @@ def __init__( See helper method :py:func:`udf` for argument details. """ + if hasattr(func, "__datafusion_scalar_udf__"): + self._udf = df_internal.ScalarUDF.from_pycapsule(func) + return if isinstance(input_types, pa.DataType): input_types = [input_types] self._udf = df_internal.ScalarUDF( @@ -134,6 +143,10 @@ def udf( name: Optional[str] = None, ) -> ScalarUDF: ... + @overload + @staticmethod + def udf(func: ScalarUDFExportable) -> ScalarUDF: ... + @staticmethod def udf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Function (UDF). @@ -147,7 +160,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417 Args: func (Callable, optional): Only needed when calling as a function. - Skip this argument when using ``udf`` as a decorator. + Skip this argument when using `udf` as a decorator. If you have a Rust + backed ScalarUDF within a PyCapsule, you can pass this parameter + and ignore the rest. They will be determined directly from the + underlying function. See the online documentation for more information. input_types (list[pa.DataType]): The data types of the arguments to ``func``. This list must be of the same length as the number of arguments. @@ -215,12 +231,31 @@ def wrapper(*args: Any, **kwargs: Any): return decorator + if hasattr(args[0], "__datafusion_scalar_udf__"): + return ScalarUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return _function(*args, **kwargs) # Case 2: Used as a decorator with parameters return _decorator(*args, **kwargs) + @staticmethod + def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF: + """Create a Scalar UDF from ScalarUDF PyCapsule object. + + This function will instantiate a Scalar UDF that uses a DataFusion + ScalarUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return ScalarUDF( + name=name, + func=func, + input_types=None, + return_type=None, + volatility=None, + ) + class Accumulator(metaclass=ABCMeta): """Defines how an :py:class:`AggregateUDF` accumulates values.""" @@ -242,6 +277,12 @@ def evaluate(self) -> pa.Scalar: """Return the resultant value.""" +class AggregateUDFExportable(Protocol): + """Type hint for object that has __datafusion_aggregate_udf__ PyCapsule.""" + + def __datafusion_aggregate_udf__(self) -> object: ... # noqa: D105 + + class AggregateUDF: """Class for performing scalar user-defined functions (UDF). @@ -263,6 +304,9 @@ def __init__( See :py:func:`udaf` for a convenience function and argument descriptions. """ + if hasattr(accumulator, "__datafusion_aggregate_udf__"): + self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator) + return self._udaf = df_internal.AggregateUDF( name, accumulator, @@ -307,7 +351,7 @@ def udaf( ) -> AggregateUDF: ... @staticmethod - def udaf(*args: Any, **kwargs: Any): # noqa: D417 + def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901 """Create a new User-Defined Aggregate Function (UDAF). This class allows you to define an aggregate function that can be used in @@ -364,6 +408,10 @@ def udf4() -> Summarize: Args: accum: The accumulator python function. Only needed when calling as a function. Skip this argument when using ``udaf`` as a decorator. + If you have a Rust backed AggregateUDF within a PyCapsule, you can + pass this parameter and ignore the rest. They will be determined + directly from the underlying function. See the online documentation + for more information. input_types: The data types of the arguments to ``accum``. return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. @@ -422,12 +470,32 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator + if hasattr(args[0], "__datafusion_aggregate_udf__"): + return AggregateUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return _function(*args, **kwargs) # Case 2: Used as a decorator with parameters return _decorator(*args, **kwargs) + @staticmethod + def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF: + """Create an Aggregate UDF from AggregateUDF PyCapsule object. + + This function will instantiate a Aggregate UDF that uses a DataFusion + AggregateUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return AggregateUDF( + name=name, + accumulator=func, + input_types=None, + return_type=None, + state_type=None, + volatility=None, + ) + class WindowEvaluator: """Evaluator class for user-defined window functions (UDWF). @@ -588,6 +656,12 @@ def include_rank(self) -> bool: return False +class WindowUDFExportable(Protocol): + """Type hint for object that has __datafusion_window_udf__ PyCapsule.""" + + def __datafusion_window_udf__(self) -> object: ... # noqa: D105 + + class WindowUDF: """Class for performing window user-defined functions (UDF). @@ -608,6 +682,9 @@ def __init__( See :py:func:`udwf` for a convenience function and argument descriptions. """ + if hasattr(func, "__datafusion_window_udf__"): + self._udwf = df_internal.WindowUDF.from_pycapsule(func) + return self._udwf = df_internal.WindowUDF( name, func, input_types, return_type, str(volatility) ) @@ -683,7 +760,10 @@ def biased_numbers() -> BiasedNumbers: Args: func: Only needed when calling as a function. Skip this argument when - using ``udwf`` as a decorator. + using ``udwf`` as a decorator. If you have a Rust backed WindowUDF + within a PyCapsule, you can pass this parameter and ignore the rest. + They will be determined directly from the underlying function. See + the online documentation for more information. input_types: The data types of the arguments. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. @@ -692,6 +772,9 @@ def biased_numbers() -> BiasedNumbers: Returns: A user-defined window function that can be used in window function calls. """ + if hasattr(args[0], "__datafusion_window_udf__"): + return WindowUDF.from_pycapsule(args[0]) + if args and callable(args[0]): # Case 1: Used as a function, require the first parameter to be callable return WindowUDF._create_window_udf(*args, **kwargs) @@ -759,6 +842,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return decorator + @staticmethod + def from_pycapsule(func: WindowUDFExportable) -> WindowUDF: + """Create a Window UDF from WindowUDF PyCapsule object. + + This function will instantiate a Window UDF that uses a DataFusion + WindowUDF that is exported via the FFI bindings. + """ + name = str(func.__class__) + return WindowUDF( + name=name, + func=func, + input_types=None, + return_type=None, + volatility=None, + ) + class TableFunction: """Class for performing user-defined table functions (UDTF). diff --git a/src/functions.rs b/src/functions.rs index b40500b8b..eeef48385 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -682,7 +682,7 @@ pub fn approx_percentile_cont_with_weight( add_builder_fns_to_aggregate(agg_fn, None, filter, None, None) } -// We handle first_value explicitly because the signature expects an order_by +// We handle last_value explicitly because the signature expects an order_by // https://github.com/apache/datafusion/issues/12376 #[pyfunction] #[pyo3(signature = (expr, distinct=None, filter=None, order_by=None, null_treatment=None))] diff --git a/src/udaf.rs b/src/udaf.rs index 34a9cd51d..78f4e2b0c 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -19,6 +19,10 @@ use std::sync::Arc; use pyo3::{prelude::*, types::PyTuple}; +use crate::common::data_type::PyScalarValue; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; +use crate::expr::PyExpr; +use crate::utils::{parse_volatility, validate_pycapsule}; use datafusion::arrow::array::{Array, ArrayRef}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; @@ -27,11 +31,8 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF, }; - -use crate::common::data_type::PyScalarValue; -use crate::errors::to_datafusion_err; -use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF}; +use pyo3::types::PyCapsule; #[derive(Debug)] struct RustAccumulator { @@ -183,6 +184,26 @@ impl PyAggregateUDF { Ok(Self { function }) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_aggregate_udf__")? { + let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_aggregate_udf")?; + + let udaf = unsafe { capsule.reference::() }; + let udaf: ForeignAggregateUDF = udaf.try_into()?; + + Ok(Self { + function: udaf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(), + )) + } + } + /// creates a new PyExpr with the call of the udf #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { diff --git a/src/udf.rs b/src/udf.rs index 574c9d7b5..de1e3f18c 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -17,6 +17,8 @@ use std::sync::Arc; +use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF}; +use pyo3::types::PyCapsule; use pyo3::{prelude::*, types::PyTuple}; use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef}; @@ -29,8 +31,9 @@ use datafusion::logical_expr::ScalarUDF; use datafusion::logical_expr::{create_udf, ColumnarValue}; use crate::errors::to_datafusion_err; +use crate::errors::{py_datafusion_err, PyDataFusionResult}; use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use crate::utils::{parse_volatility, validate_pycapsule}; /// Create a Rust callable function from a python function that expects pyarrow arrays fn pyarrow_function_to_rust( @@ -105,6 +108,26 @@ impl PyScalarUDF { Ok(Self { function }) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_scalar_udf__")? { + let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_scalar_udf")?; + + let udf = unsafe { capsule.reference::() }; + let udf: ForeignScalarUDF = udf.try_into()?; + + Ok(Self { + function: udf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_scalar_udf__ does not exist on ScalarUDF object.".to_string(), + )) + } + } + /// creates a new PyExpr with the call of the udf #[pyo3(signature = (*args))] fn __call__(&self, args: Vec) -> PyResult { diff --git a/src/udwf.rs b/src/udwf.rs index a0c8cc59a..4fb98916b 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -27,16 +27,17 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use crate::common::data_type::PyScalarValue; -use crate::errors::to_datafusion_err; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; use crate::expr::PyExpr; -use crate::utils::parse_volatility; +use crate::utils::{parse_volatility, validate_pycapsule}; use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::{ PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl, }; -use pyo3::types::{PyList, PyTuple}; +use datafusion_ffi::udwf::{FFI_WindowUDF, ForeignWindowUDF}; +use pyo3::types::{PyCapsule, PyList, PyTuple}; #[derive(Debug)] struct RustPartitionEvaluator { @@ -245,6 +246,26 @@ impl PyWindowUDF { Ok(self.function.call(args).into()) } + #[staticmethod] + pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult { + if func.hasattr("__datafusion_window_udf__")? { + let capsule = func.getattr("__datafusion_window_udf__")?.call0()?; + let capsule = capsule.downcast::().map_err(py_datafusion_err)?; + validate_pycapsule(capsule, "datafusion_window_udf")?; + + let udwf = unsafe { capsule.reference::() }; + let udwf: ForeignWindowUDF = udwf.try_into()?; + + Ok(Self { + function: udwf.into(), + }) + } else { + Err(crate::errors::PyDataFusionError::Common( + "__datafusion_window_udf__ does not exist on WindowUDF object.".to_string(), + )) + } + } + fn __repr__(&self) -> PyResult { Ok(format!("WindowUDF({})", self.function.name())) } From 39b5971c0d615d6047e75dc200fc31ae9c427397 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Jun 2025 15:27:57 -0400 Subject: [PATCH 3/3] Add default in memory options for adding schema and catalogs --- python/datafusion/catalog.py | 5 +++++ python/datafusion/context.py | 9 +++++++++ src/catalog.rs | 11 +++++++++++ src/context.rs | 13 ++++++++++++- 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/datafusion/catalog.py b/python/datafusion/catalog.py index bebd38161..5f1a317f6 100644 --- a/python/datafusion/catalog.py +++ b/python/datafusion/catalog.py @@ -73,6 +73,11 @@ def database(self, name: str = "public") -> Schema: """Returns the database with the given ``name`` from this catalog.""" return self.schema(name) + def new_in_memory_schema(self, name: str) -> Schema: + """Create a new schema in this catalog using an in-memory provider.""" + self.catalog.new_in_memory_schema(name) + return self.schema(name) + def register_schema(self, name, schema) -> Schema | None: """Register a schema with this catalog.""" return self.catalog.register_schema(name, schema) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 4dc9f42f0..b85f46f8e 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -755,6 +755,15 @@ def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + def catalog_names(self) -> set[str]: + """Returns the list of catalogs in this context.""" + return self.ctx.catalog_names() + + def new_in_memory_catalog(self, name: str) -> Catalog: + """Create a new catalog in this context using an in-memory provider.""" + self.ctx.new_in_memory_catalog(name) + return self.catalog(name) + def register_catalog_provider( self, name: str, provider: CatalogProviderExportable ) -> None: diff --git a/src/catalog.rs b/src/catalog.rs index ba96ce471..9a24f2d44 100644 --- a/src/catalog.rs +++ b/src/catalog.rs @@ -19,6 +19,7 @@ use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::utils::{validate_pycapsule, wait_for_future}; use async_trait::async_trait; +use datafusion::catalog::MemorySchemaProvider; use datafusion::common::DataFusionError; use datafusion::{ arrow::pyarrow::ToPyArrow, @@ -105,6 +106,16 @@ impl PyCatalog { }) } + fn new_in_memory_schema(&mut self, name: &str) -> PyResult<()> { + let schema = Arc::new(MemorySchemaProvider::new()) as Arc; + let _ = self + .catalog + .register_schema(name, schema) + .map_err(py_datafusion_err)?; + + Ok(()) + } + fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> { let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? { let capsule = schema_provider diff --git a/src/context.rs b/src/context.rs index c92c84acc..0b1cfc189 100644 --- a/src/context.rs +++ b/src/context.rs @@ -49,7 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion::arrow::pyarrow::PyArrowType; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::catalog::CatalogProvider; +use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider}; use datafusion::common::TableReference; use datafusion::common::{exec_err, ScalarValue}; use datafusion::datasource::file_format::file_compression_type::FileCompressionType; @@ -612,6 +612,13 @@ impl PySessionContext { Ok(()) } + pub fn new_in_memory_catalog(&mut self, name: &str) -> PyResult<()> { + let catalog = Arc::new(MemoryCatalogProvider::new()) as Arc; + let _ = self.ctx.register_catalog(name, catalog); + + Ok(()) + } + pub fn register_catalog_provider( &mut self, name: &str, @@ -870,6 +877,10 @@ impl PySessionContext { }) } + pub fn catalog_names(&self) -> HashSet { + self.ctx.catalog_names().into_iter().collect() + } + pub fn tables(&self) -> HashSet { self.ctx .catalog_names()