Skip to content

Commit 269d9bb

Browse files
committed
Enhance Table initialization to support DataFrame conversion and update PySessionContext methods to be immutable
1 parent 4260835 commit 269d9bb

File tree

4 files changed

+31
-6
lines changed

4 files changed

+31
-6
lines changed

python/datafusion/catalog.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -166,11 +166,16 @@ class Table:
166166

167167
def __init__(
168168
self,
169-
table: _InternalRawTable | _InternalTableProvider | Table,
169+
table: _InternalRawTable | _InternalTableProvider | Table | Any,
170170
) -> None:
171-
"""Wrap a low level table or table provider."""
171+
"""Wrap a low level table, table provider, or convertibles like DataFrame."""
172172
if isinstance(table, Table):
173173
table = table.table
174+
else:
175+
from datafusion.dataframe import DataFrame as DataFrameWrapper
176+
177+
if isinstance(table, DataFrameWrapper):
178+
table = table.df.into_view()
174179

175180
if not isinstance(table, (_InternalRawTable, _InternalTableProvider)):
176181
raise TypeError(EXPECTED_PROVIDER_MSG)

python/tests/test_catalog.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -196,6 +196,20 @@ def test_schema_register_table_with_dataframe_errors(ctx: SessionContext):
196196
assert str(exc_info.value) == EXPECTED_PROVIDER_MSG
197197

198198

199+
def test_table_wraps_dataframe(ctx: SessionContext):
200+
df = ctx.sql("SELECT 1 AS value")
201+
202+
table = Table(df)
203+
ctx.register_table("df_table", table)
204+
205+
try:
206+
batches = ctx.sql("SELECT value FROM df_table").collect()
207+
assert len(batches) == 1
208+
assert batches[0].column(0).to_pylist() == [1]
209+
finally:
210+
ctx.deregister_table("df_table")
211+
212+
199213
def test_in_end_to_end_python_providers(ctx: SessionContext):
200214
"""Test registering all python providers and running a query against them."""
201215

src/context.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -429,15 +429,15 @@ impl PySessionContext {
429429
}
430430

431431
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
432-
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
432+
pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
433433
let result = self.ctx.sql(query);
434434
let df = wait_for_future(py, result)??;
435435
Ok(PyDataFrame::new(df))
436436
}
437437

438438
#[pyo3(signature = (query, options=None))]
439439
pub fn sql_with_options(
440-
&mut self,
440+
&self,
441441
query: &str,
442442
options: Option<PySQLOptions>,
443443
py: Python,

src/utils.rs

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ where
8282
const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
8383

8484
py.allow_threads(|| {
85-
runtime.block_on(async {
85+
let wait_future = || async {
8686
tokio::pin!(fut);
8787
loop {
8888
tokio::select! {
@@ -92,7 +92,13 @@ where
9292
}
9393
}
9494
}
95-
})
95+
};
96+
97+
if tokio::runtime::Handle::try_current().is_ok() {
98+
tokio::task::block_in_place(|| runtime.block_on(wait_future()))
99+
} else {
100+
runtime.block_on(wait_future())
101+
}
96102
})
97103
}
98104

0 commit comments

Comments
 (0)