Skip to content

Commit 43d87a6

Browse files
committed
Add support for creating in memory catalog and schema
1 parent 2b300b5 commit 43d87a6

File tree

5 files changed

+72
-35
lines changed

5 files changed

+72
-35
lines changed

python/datafusion/catalog.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ def schema_names(self) -> set[str]:
5858
"""Returns the list of schemas in this catalog."""
5959
return self.catalog.schema_names()
6060

61+
@staticmethod
62+
def memory_catalog() -> Catalog:
63+
"""Create an in-memory catalog provider."""
64+
catalog = df_internal.catalog.RawCatalog.memory_catalog()
65+
return Catalog(catalog)
66+
6167
def schema(self, name: str = "public") -> Schema:
6268
"""Returns the database with the given ``name`` from this catalog."""
6369
schema = self.catalog.schema(name)
@@ -73,13 +79,10 @@ def database(self, name: str = "public") -> Schema:
7379
"""Returns the database with the given ``name`` from this catalog."""
7480
return self.schema(name)
7581

76-
def new_in_memory_schema(self, name: str) -> Schema:
77-
"""Create a new schema in this catalog using an in-memory provider."""
78-
self.catalog.new_in_memory_schema(name)
79-
return self.schema(name)
80-
8182
def register_schema(self, name, schema) -> Schema | None:
8283
"""Register a schema with this catalog."""
84+
if isinstance(schema, Schema):
85+
return self.catalog.register_schema(name, schema._raw_schema)
8386
return self.catalog.register_schema(name, schema)
8487

8588
def deregister_schema(self, name: str, cascade: bool = True) -> Schema | None:
@@ -98,6 +101,12 @@ def __repr__(self) -> str:
98101
"""Print a string representation of the schema."""
99102
return self._raw_schema.__repr__()
100103

104+
@staticmethod
105+
def memory_schema() -> Schema:
106+
"""Create an in-memory schema provider."""
107+
schema = df_internal.catalog.RawSchema.memory_schema()
108+
return Schema(schema)
109+
101110
def names(self) -> set[str]:
102111
"""This is an alias for `table_names`."""
103112
return self.table_names()

python/datafusion/context.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -762,16 +762,14 @@ def catalog_names(self) -> set[str]:
762762
"""Returns the list of catalogs in this context."""
763763
return self.ctx.catalog_names()
764764

765-
def new_in_memory_catalog(self, name: str) -> Catalog:
766-
"""Create a new catalog in this context using an in-memory provider."""
767-
self.ctx.new_in_memory_catalog(name)
768-
return self.catalog(name)
769-
770765
def register_catalog_provider(
771-
self, name: str, provider: CatalogProviderExportable
766+
self, name: str, provider: CatalogProviderExportable | Catalog
772767
) -> None:
773768
"""Register a catalog provider."""
774-
self.ctx.register_catalog_provider(name, provider)
769+
if isinstance(provider, Catalog):
770+
self.ctx.register_catalog_provider(name, provider.catalog)
771+
else:
772+
self.ctx.register_catalog_provider(name, provider)
775773

776774
def register_table_provider(
777775
self, name: str, provider: TableProviderExportable

python/tests/test_catalog.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,24 @@ def test_python_catalog_provider(ctx: SessionContext):
106106
assert my_catalog.schema_names() == {"second_schema"}
107107

108108

109+
def test_in_memory_providers(ctx: SessionContext):
110+
catalog = dfn.catalog.Catalog.memory_catalog()
111+
ctx.register_catalog_provider("in_mem_catalog", catalog)
112+
113+
assert ctx.catalog_names() == {"datafusion", "in_mem_catalog"}
114+
115+
schema = dfn.catalog.Schema.memory_schema()
116+
catalog.register_schema("in_mem_schema", schema)
117+
118+
schema.register_table("my_table", create_dataset())
119+
120+
batches = ctx.sql("select * from in_mem_catalog.in_mem_schema.my_table").collect()
121+
122+
assert len(batches) == 1
123+
assert batches[0].column(0) == pa.array([1, 2, 3])
124+
assert batches[0].column(1) == pa.array([4, 5, 6])
125+
126+
109127
def test_python_schema_provider(ctx: SessionContext):
110128
catalog = ctx.catalog()
111129

src/catalog.rs

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::dataset::Dataset;
1919
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
2020
use crate::utils::{validate_pycapsule, wait_for_future};
2121
use async_trait::async_trait;
22-
use datafusion::catalog::MemorySchemaProvider;
22+
use datafusion::catalog::{MemoryCatalogProvider, MemorySchemaProvider};
2323
use datafusion::common::DataFusionError;
2424
use datafusion::{
2525
arrow::pyarrow::ToPyArrow,
@@ -37,16 +37,19 @@ use std::collections::HashSet;
3737
use std::sync::Arc;
3838

3939
#[pyclass(name = "RawCatalog", module = "datafusion.catalog", subclass)]
40+
#[derive(Clone)]
4041
pub struct PyCatalog {
4142
pub catalog: Arc<dyn CatalogProvider>,
4243
}
4344

4445
#[pyclass(name = "RawSchema", module = "datafusion.catalog", subclass)]
46+
#[derive(Clone)]
4547
pub struct PySchema {
4648
pub schema: Arc<dyn SchemaProvider>,
4749
}
4850

4951
#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)]
52+
#[derive(Clone)]
5053
pub struct PyTable {
5154
pub table: Arc<dyn TableProvider>,
5255
}
@@ -82,6 +85,13 @@ impl PyCatalog {
8285
catalog_provider.into()
8386
}
8487

88+
#[staticmethod]
89+
fn memory_catalog() -> Self {
90+
let catalog_provider =
91+
Arc::new(MemoryCatalogProvider::default()) as Arc<dyn CatalogProvider>;
92+
catalog_provider.into()
93+
}
94+
8595
fn schema_names(&self) -> HashSet<String> {
8696
self.catalog.schema_names().into_iter().collect()
8797
}
@@ -106,16 +116,6 @@ impl PyCatalog {
106116
})
107117
}
108118

109-
fn new_in_memory_schema(&mut self, name: &str) -> PyResult<()> {
110-
let schema = Arc::new(MemorySchemaProvider::new()) as Arc<dyn SchemaProvider>;
111-
let _ = self
112-
.catalog
113-
.register_schema(name, schema)
114-
.map_err(py_datafusion_err)?;
115-
116-
Ok(())
117-
}
118-
119119
fn register_schema(&self, name: &str, schema_provider: Bound<'_, PyAny>) -> PyResult<()> {
120120
let provider = if schema_provider.hasattr("__datafusion_schema_provider__")? {
121121
let capsule = schema_provider
@@ -128,8 +128,11 @@ impl PyCatalog {
128128
let provider: ForeignSchemaProvider = provider.into();
129129
Arc::new(provider) as Arc<dyn SchemaProvider>
130130
} else {
131-
let provider = RustWrappedPySchemaProvider::new(schema_provider.into());
132-
Arc::new(provider) as Arc<dyn SchemaProvider>
131+
match schema_provider.extract::<PySchema>() {
132+
Ok(py_schema) => py_schema.schema,
133+
Err(_) => Arc::new(RustWrappedPySchemaProvider::new(schema_provider.into()))
134+
as Arc<dyn SchemaProvider>,
135+
}
133136
};
134137

135138
let _ = self
@@ -165,6 +168,12 @@ impl PySchema {
165168
schema_provider.into()
166169
}
167170

171+
#[staticmethod]
172+
fn memory_schema() -> Self {
173+
let schema_provider = Arc::new(MemorySchemaProvider::default()) as Arc<dyn SchemaProvider>;
174+
schema_provider.into()
175+
}
176+
168177
#[getter]
169178
fn table_names(&self) -> HashSet<String> {
170179
self.schema.table_names().into_iter().collect()

src/context.rs

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f
4949
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
5050
use datafusion::arrow::pyarrow::PyArrowType;
5151
use datafusion::arrow::record_batch::RecordBatch;
52-
use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider};
52+
use datafusion::catalog::CatalogProvider;
5353
use datafusion::common::TableReference;
5454
use datafusion::common::{exec_err, ScalarValue};
5555
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -617,13 +617,6 @@ impl PySessionContext {
617617
Ok(())
618618
}
619619

620-
pub fn new_in_memory_catalog(&mut self, name: &str) -> PyResult<()> {
621-
let catalog = Arc::new(MemoryCatalogProvider::new()) as Arc<dyn CatalogProvider>;
622-
let _ = self.ctx.register_catalog(name, catalog);
623-
624-
Ok(())
625-
}
626-
627620
pub fn register_catalog_provider(
628621
&mut self,
629622
name: &str,
@@ -640,8 +633,18 @@ impl PySessionContext {
640633
let provider: ForeignCatalogProvider = provider.into();
641634
Arc::new(provider) as Arc<dyn CatalogProvider>
642635
} else {
643-
let provider = RustWrappedPyCatalogProvider::new(provider.into());
644-
Arc::new(provider) as Arc<dyn CatalogProvider>
636+
println!("Provider has type {}", provider.get_type());
637+
match provider.extract::<PyCatalog>() {
638+
Ok(py_catalog) => {
639+
println!("registering an existing PyCatalog");
640+
py_catalog.catalog
641+
}
642+
Err(_) => {
643+
println!("registering a rust wrapped catalog provider");
644+
Arc::new(RustWrappedPyCatalogProvider::new(provider.into()))
645+
as Arc<dyn CatalogProvider>
646+
}
647+
}
645648
};
646649

647650
let _ = self.ctx.register_catalog(name, provider);

0 commit comments

Comments
 (0)