Skip to content

Commit 91b90f4

Browse files
committed
Harden conversion of datafusion_table_provider capsules
Harden capsule conversion by verifying the capsule's destructor, pointer alignment, and required function pointers before instantiating ForeignTableProvider. Introduce RawTable.from_table_provider_capsule so that SessionContext.read_table can reuse the same capsule normalization. Add a regression test ensuring that invalid capsules raise a ValueError instead of causing a segmentation fault.
1 parent b9bf5c5 commit 91b90f4

File tree

5 files changed

+120
-8
lines changed

5 files changed

+120
-8
lines changed

python/datafusion/catalog.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,14 @@ def from_dataset(dataset: pa.dataset.Dataset) -> Table:
178178
"""Turn a :mod:`pyarrow.dataset` ``Dataset`` into a :class:`Table`."""
179179
return Table(dataset)
180180

181+
@staticmethod
182+
def from_table_provider_capsule(capsule: object) -> Table:
183+
"""Wrap a validated table provider :class:`PyCapsule` as a :class:`Table`."""
184+
185+
return Table(
186+
df_internal.catalog.RawTable.from_table_provider_capsule(capsule)
187+
)
188+
181189
@property
182190
def schema(self) -> pa.Schema:
183191
"""Returns the schema associated with this table."""

python/datafusion/context.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import pyarrow as pa
3131

32-
from datafusion.catalog import Catalog
32+
from datafusion.catalog import Catalog, Table
3333
from datafusion.dataframe import DataFrame
3434
from datafusion.expr import sort_list_to_raw_sort_list
3535
from datafusion.record_batch import RecordBatchStream
@@ -1181,6 +1181,16 @@ def read_table(
11811181
self, table: Table | TableProviderExportable | DataFrame | pa.dataset.Dataset
11821182
) -> DataFrame:
11831183
"""Creates a :py:class:`~datafusion.dataframe.DataFrame` from a table."""
1184+
if not isinstance(table, Table):
1185+
capsule_obj: object | None = None
1186+
if hasattr(table, "__datafusion_table_provider__"):
1187+
capsule_obj = table.__datafusion_table_provider__()
1188+
elif table.__class__.__name__ == "PyCapsule":
1189+
capsule_obj = table
1190+
1191+
if capsule_obj is not None:
1192+
table = Table.from_table_provider_capsule(capsule_obj)
1193+
11841194
return DataFrame(self.ctx.read_table(table))
11851195

11861196
def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:

python/tests/test_context.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import ctypes
1718
import datetime as dt
1819
import gzip
1920
import pathlib
@@ -341,6 +342,35 @@ def test_read_table_from_dataset(ctx):
341342
assert result[0].column(1) == pa.array([4, 5, 6])
342343

343344

345+
def test_read_table_rejects_invalid_table_provider_capsule(ctx):
346+
class CapsuleContainer:
347+
def __init__(self) -> None:
348+
self._buffer = ctypes.create_string_buffer(b"x")
349+
350+
def __datafusion_table_provider__(self) -> object:
351+
pycapsule_new = ctypes.pythonapi.PyCapsule_New
352+
pycapsule_new.restype = ctypes.py_object
353+
pycapsule_new.argtypes = [
354+
ctypes.c_void_p,
355+
ctypes.c_char_p,
356+
ctypes.c_void_p,
357+
]
358+
dummy_ptr = ctypes.cast(self._buffer, ctypes.c_void_p)
359+
return pycapsule_new(
360+
dummy_ptr, b"datafusion_table_provider", None
361+
)
362+
363+
container = CapsuleContainer()
364+
365+
with pytest.raises(ValueError, match="missing a destructor"):
366+
ctx.read_table(container)
367+
368+
with pytest.raises(ValueError, match="missing a destructor"):
369+
Table.from_table_provider_capsule(
370+
container.__datafusion_table_provider__()
371+
)
372+
373+
344374
def test_deregister_table(ctx, database):
345375
default = ctx.catalog()
346376
public = default.schema("public")

src/table.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
use arrow::pyarrow::ToPyArrow;
1919
use datafusion::datasource::{TableProvider, TableType};
2020
use pyo3::prelude::*;
21+
use pyo3::types::PyCapsule;
2122
use std::sync::Arc;
2223

2324
use crate::dataframe::PyDataFrame;
2425
use crate::dataset::Dataset;
25-
use crate::utils::table_provider_from_pycapsule;
26+
use crate::errors::py_datafusion_err;
27+
use crate::utils::{table_provider_from_capsule, table_provider_from_pycapsule};
2628

2729
/// This struct is used as a common method for all TableProviders,
2830
/// whether they refer to an FFI provider, an internally known
@@ -77,6 +79,13 @@ impl PyTable {
7779
}
7880
}
7981

82+
#[staticmethod]
83+
pub fn from_table_provider_capsule(capsule: &Bound<'_, PyAny>) -> PyResult<Self> {
84+
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
85+
let provider = table_provider_from_capsule(&capsule)?;
86+
Ok(PyTable::from(provider))
87+
}
88+
8089
/// Get a reference to the schema for this table
8190
#[getter]
8291
fn schema(&self, py: Python) -> PyResult<PyObject> {

src/utils.rs

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use datafusion::{
2727
};
2828
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2929
use pyo3::prelude::*;
30-
use pyo3::{exceptions::PyValueError, types::PyCapsule};
30+
use pyo3::{exceptions::PyValueError, ffi, types::PyCapsule};
3131
use std::{
3232
future::Future,
3333
sync::{Arc, OnceLock},
@@ -124,18 +124,73 @@ pub(crate) fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyRe
124124
Ok(())
125125
}
126126

127+
fn ensure_capsule_has_destructor(capsule: &Bound<PyCapsule>) -> PyResult<()> {
128+
if unsafe { ffi::PyCapsule_GetDestructor(capsule.as_ptr()) }.is_none() {
129+
return Err(PyValueError::new_err(
130+
"Table provider capsule is missing a destructor; ensure it was created via datafusion_ffi's helpers.",
131+
));
132+
}
133+
134+
Ok(())
135+
}
136+
137+
fn ensure_capsule_pointer(capsule: &Bound<PyCapsule>) -> PyResult<()> {
138+
let ptr = capsule.pointer();
139+
if ptr.is_null() {
140+
return Err(PyValueError::new_err(
141+
"Table provider capsule contained a null pointer.",
142+
));
143+
}
144+
145+
if (ptr as usize) % std::mem::align_of::<FFI_TableProvider>() != 0 {
146+
return Err(PyValueError::new_err(
147+
"Table provider capsule pointer was not aligned for FFI_TableProvider.",
148+
));
149+
}
150+
151+
Ok(())
152+
}
153+
154+
fn validate_foreign_table_provider(provider: &FFI_TableProvider) -> PyResult<()> {
155+
if provider.schema as usize == 0
156+
|| provider.scan as usize == 0
157+
|| provider.table_type as usize == 0
158+
|| provider.clone as usize == 0
159+
|| provider.release as usize == 0
160+
|| provider.version as usize == 0
161+
|| provider.private_data.is_null()
162+
{
163+
return Err(PyValueError::new_err(
164+
"Table provider capsule is missing required function pointers.",
165+
));
166+
}
167+
168+
Ok(())
169+
}
170+
171+
pub(crate) fn table_provider_from_capsule(
172+
capsule: &Bound<PyCapsule>,
173+
) -> PyResult<Arc<dyn TableProvider>> {
174+
validate_pycapsule(capsule, "datafusion_table_provider")?;
175+
ensure_capsule_has_destructor(capsule)?;
176+
ensure_capsule_pointer(capsule)?;
177+
178+
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
179+
validate_foreign_table_provider(provider)?;
180+
let provider: ForeignTableProvider = provider.into();
181+
182+
Ok(Arc::new(provider))
183+
}
184+
127185
pub(crate) fn table_provider_from_pycapsule(
128186
obj: &Bound<PyAny>,
129187
) -> PyResult<Option<Arc<dyn TableProvider>>> {
130188
if obj.hasattr("__datafusion_table_provider__")? {
131189
let capsule = obj.getattr("__datafusion_table_provider__")?.call0()?;
132190
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
133-
validate_pycapsule(capsule, "datafusion_table_provider")?;
134-
135-
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
136-
let provider: ForeignTableProvider = provider.into();
191+
let provider = table_provider_from_capsule(&capsule)?;
137192

138-
Ok(Some(Arc::new(provider)))
193+
Ok(Some(provider))
139194
} else {
140195
Ok(None)
141196
}

0 commit comments

Comments
 (0)