Skip to content

Commit 0ff5e90

Browse files
committed
feat: enhance TableProvider to support Send trait across multiple modules
1 parent 207a3b5 commit 0ff5e90

File tree

5 files changed

+58
-19
lines changed

5 files changed

+58
-19
lines changed

src/catalog.rs

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ pub struct PySchema {
5252
#[pyclass(name = "RawTable", module = "datafusion.catalog", subclass)]
5353
#[derive(Clone)]
5454
pub struct PyTable {
55-
pub table: Arc<dyn TableProvider>,
55+
pub table: Arc<dyn TableProvider + Send>,
5656
}
5757

5858
impl From<Arc<dyn CatalogProvider>> for PyCatalog {
@@ -68,11 +68,11 @@ impl From<Arc<dyn SchemaProvider>> for PySchema {
6868
}
6969

7070
impl PyTable {
71-
pub fn new(table: Arc<dyn TableProvider>) -> Self {
71+
pub fn new(table: Arc<dyn TableProvider + Send>) -> Self {
7272
Self { table }
7373
}
7474

75-
pub fn table(&self) -> Arc<dyn TableProvider> {
75+
pub fn table(&self) -> Arc<dyn TableProvider + Send> {
7676
self.table.clone()
7777
}
7878
}
@@ -206,7 +206,7 @@ impl PySchema {
206206

207207
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
208208
let provider: ForeignTableProvider = provider.into();
209-
Arc::new(provider) as Arc<dyn TableProvider>
209+
Arc::new(provider) as Arc<dyn TableProvider + Send>
210210
} else {
211211
match table_provider.extract::<PyTable>() {
212212
Ok(py_table) => py_table.table,
@@ -215,7 +215,7 @@ impl PySchema {
215215
Err(_) => {
216216
let py = table_provider.py();
217217
let provider = Dataset::new(&table_provider, py)?;
218-
Arc::new(provider) as Arc<dyn TableProvider>
218+
Arc::new(provider) as Arc<dyn TableProvider + Send>
219219
}
220220
},
221221
}
@@ -298,7 +298,7 @@ impl RustWrappedPySchemaProvider {
298298
}
299299
}
300300

301-
fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider>>> {
301+
fn table_inner(&self, name: &str) -> PyResult<Option<Arc<dyn TableProvider + Send>>> {
302302
Python::with_gil(|py| {
303303
let provider = self.schema_provider.bind(py);
304304
let py_table_method = provider.getattr("table")?;
@@ -316,7 +316,7 @@ impl RustWrappedPySchemaProvider {
316316
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
317317
let provider: ForeignTableProvider = provider.into();
318318

319-
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider>))
319+
Ok(Some(Arc::new(provider) as Arc<dyn TableProvider + Send>))
320320
} else {
321321
if let Ok(inner_table) = py_table.getattr("table") {
322322
if let Ok(inner_table) = inner_table.extract::<PyTable>() {
@@ -332,7 +332,7 @@ impl RustWrappedPySchemaProvider {
332332
Ok(py_table) => Ok(Some(py_table.table)),
333333
Err(_) => {
334334
let ds = Dataset::new(&py_table, py).map_err(py_datafusion_err)?;
335-
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider>))
335+
Ok(Some(Arc::new(ds) as Arc<dyn TableProvider + Send>))
336336
}
337337
}
338338
}
@@ -368,15 +368,32 @@ impl SchemaProvider for RustWrappedPySchemaProvider {
368368
&self,
369369
name: &str,
370370
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
371-
self.table_inner(name).map_err(to_datafusion_err)
371+
// Convert from our internal Send type to the trait expected type
372+
match self.table_inner(name).map_err(to_datafusion_err)? {
373+
Some(table) => {
374+
// Safe conversion: we're widening the bounds (removing Send)
375+
let raw = Arc::into_raw(table);
376+
let wide: *const dyn TableProvider = raw as *const _;
377+
let arc = unsafe { Arc::from_raw(wide) };
378+
Ok(Some(arc))
379+
}
380+
None => Ok(None),
381+
}
372382
}
373383

374384
fn register_table(
375385
&self,
376386
name: String,
377387
table: Arc<dyn TableProvider>,
378388
) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
379-
let py_table = PyTable::new(table);
389+
// Convert from trait type to our internal Send type
390+
let send_table = {
391+
let raw = Arc::into_raw(table);
392+
let send: *const (dyn TableProvider + Send) = raw as *const _;
393+
unsafe { Arc::from_raw(send) }
394+
};
395+
396+
let py_table = PyTable::new(send_table);
380397
Python::with_gil(|py| {
381398
let provider = self.schema_provider.bind(py);
382399
let _ = provider
@@ -405,7 +422,14 @@ impl SchemaProvider for RustWrappedPySchemaProvider {
405422
// If we can turn this table provider into a `Dataset`, return it.
406423
// Otherwise, return None.
407424
let dataset = match Dataset::new(&table, py) {
408-
Ok(dataset) => Some(Arc::new(dataset) as Arc<dyn TableProvider>),
425+
Ok(dataset) => {
426+
// Convert from our internal Send type to trait expected type
427+
let send_table = Arc::new(dataset) as Arc<dyn TableProvider + Send>;
428+
let raw = Arc::into_raw(send_table);
429+
let wide: *const dyn TableProvider = raw as *const _;
430+
let arc = unsafe { Arc::from_raw(wide) };
431+
Some(arc)
432+
}
409433
Err(_) => None,
410434
};
411435

src/context.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ impl PySessionContext {
617617

618618
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
619619
let provider: ForeignTableProvider = provider.into();
620-
Arc::new(provider) as Arc<dyn TableProvider>
620+
Arc::new(provider) as Arc<dyn TableProvider + Send>
621621
} else if let Ok(py_table) = table_provider.extract::<PyTable>() {
622622
py_table.table()
623623
} else if let Ok(py_provider) = table_provider.extract::<PyTableProvider>() {
@@ -858,7 +858,7 @@ impl PySessionContext {
858858
dataset: &Bound<'_, PyAny>,
859859
py: Python,
860860
) -> PyDataFusionResult<()> {
861-
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
861+
let table: Arc<dyn TableProvider + Send> = Arc::new(Dataset::new(dataset, py)?);
862862

863863
self.ctx.register_table(name, table)?;
864864

src/dataframe.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ use crate::expr::sort_expr::to_sort_expressions;
4848
use crate::physical_plan::PyExecutionPlan;
4949
use crate::record_batch::PyRecordBatchStream;
5050
use crate::sql::logical::PyLogicalPlan;
51-
use crate::table::PyTableProvider;
51+
pub use crate::table::PyTableProvider;
5252
use crate::utils::{
5353
get_tokio_runtime, is_ipython_env, py_obj_to_scalar_value, validate_pycapsule, wait_for_future,
5454
};
@@ -268,7 +268,7 @@ impl PyDataFrame {
268268
}
269269
}
270270

271-
pub(crate) fn into_view_provider(&self) -> Arc<dyn TableProvider> {
271+
pub(crate) fn into_view_provider(&self) -> Arc<dyn TableProvider + Send> {
272272
self.df.as_ref().clone().into_view()
273273
}
274274

src/table.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,28 @@ use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2323
use pyo3::prelude::*;
2424
use pyo3::types::PyCapsule;
2525

26+
use crate::catalog::PyTable;
2627
use crate::dataframe::PyDataFrame;
2728
use crate::errors::{py_datafusion_err, PyDataFusionResult};
2829
use crate::utils::{get_tokio_runtime, validate_pycapsule};
2930

3031
/// Represents a table provider that can be registered with DataFusion
3132
#[pyclass(name = "TableProvider", module = "datafusion")]
33+
#[derive(Clone)]
3234
pub struct PyTableProvider {
3335
pub(crate) provider: Arc<dyn TableProvider + Send>,
3436
}
3537

3638
impl PyTableProvider {
37-
pub(crate) fn new(provider: Arc<dyn TableProvider>) -> Self {
39+
pub(crate) fn new(provider: Arc<dyn TableProvider + Send>) -> Self {
3840
Self { provider }
3941
}
42+
43+
/// Return a `PyTable` wrapper around this provider so callers can call
44+
/// `as_table().table()` to get the underlying `Arc<dyn TableProvider + Send>`.
45+
pub fn as_table(&self) -> PyTable {
46+
PyTable::new(Arc::clone(&self.provider))
47+
}
4048
}
4149

4250
#[pymethods]

src/udtf.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ impl PyTableFunction {
8787
fn call_python_table_function(
8888
func: &Arc<PyObject>,
8989
args: &[Expr],
90-
) -> DataFusionResult<Arc<dyn TableProvider>> {
90+
) -> DataFusionResult<Arc<dyn TableProvider + Send>> {
9191
let args = args
9292
.iter()
9393
.map(|arg| PyExpr::from(arg.clone()))
@@ -107,7 +107,7 @@ fn call_python_table_function(
107107
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
108108
let provider: ForeignTableProvider = provider.into();
109109

110-
Ok(Arc::new(provider) as Arc<dyn TableProvider>)
110+
Ok(Arc::new(provider) as Arc<dyn TableProvider + Send>)
111111
} else {
112112
Err(PyNotImplementedError::new_err(
113113
"__datafusion_table_provider__ does not exist on Table Provider object.",
@@ -121,7 +121,14 @@ impl TableFunctionImpl for PyTableFunction {
121121
fn call(&self, args: &[Expr]) -> DataFusionResult<Arc<dyn TableProvider>> {
122122
match &self.inner {
123123
PyTableFunctionInner::FFIFunction(func) => func.call(args),
124-
PyTableFunctionInner::PythonFunction(obj) => call_python_table_function(obj, args),
124+
PyTableFunctionInner::PythonFunction(obj) => {
125+
let send_result = call_python_table_function(obj, args)?;
126+
// Convert from our Send type to the trait expected type
127+
let raw = Arc::into_raw(send_result);
128+
let wide: *const dyn TableProvider = raw as *const _;
129+
let arc = unsafe { Arc::from_raw(wide) };
130+
Ok(arc)
131+
}
125132
}
126133
}
127134
}

0 commit comments

Comments
 (0)