Skip to content

Commit 6e449da

Browse files
committed
Enhance table_provider_from_pycapsule handling
Updated the table_provider_from_pycapsule function to recognize raw PyCapsule objects directly and validate them before falling back to attribute-based discovery via __datafusion_table_provider__. Also added a reusable test helper that creates dummy (invalid) table-provider capsules, and an integration test that confirms RawTable.from_table_provider_capsule is exposed and exercises SessionContext.read_table with a raw capsule, expecting the invalid capsule to be rejected with a ValueError.
1 parent d784c4d commit 6e449da

File tree

2 files changed

+35
-13
lines changed

2 files changed

+35
-13
lines changed

python/tests/test_context.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,20 @@
3333
literal,
3434
)
3535

36+
_PYCAPSULE_NEW = ctypes.pythonapi.PyCapsule_New
37+
_PYCAPSULE_NEW.restype = ctypes.py_object
38+
_PYCAPSULE_NEW.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
39+
40+
41+
def _make_invalid_table_provider_capsule():
42+
backing = ctypes.create_string_buffer(b"x")
43+
capsule = _PYCAPSULE_NEW(
44+
ctypes.cast(backing, ctypes.c_void_p),
45+
b"datafusion_table_provider",
46+
None,
47+
)
48+
return capsule, backing
49+
3650

3751
def test_create_context_no_args():
3852
SessionContext()
@@ -345,20 +359,12 @@ def test_read_table_from_dataset(ctx):
345359
def test_read_table_rejects_invalid_table_provider_capsule(ctx):
346360
class CapsuleContainer:
347361
def __init__(self) -> None:
348-
self._buffer = ctypes.create_string_buffer(b"x")
362+
self._buffers: list[ctypes.Array[ctypes.c_char]] = []
349363

350364
def __datafusion_table_provider__(self) -> object:
351-
pycapsule_new = ctypes.pythonapi.PyCapsule_New
352-
pycapsule_new.restype = ctypes.py_object
353-
pycapsule_new.argtypes = [
354-
ctypes.c_void_p,
355-
ctypes.c_char_p,
356-
ctypes.c_void_p,
357-
]
358-
dummy_ptr = ctypes.cast(self._buffer, ctypes.c_void_p)
359-
return pycapsule_new(
360-
dummy_ptr, b"datafusion_table_provider", None
361-
)
365+
capsule, backing = _make_invalid_table_provider_capsule()
366+
self._buffers.append(backing)
367+
return capsule
362368

363369
container = CapsuleContainer()
364370

@@ -371,6 +377,18 @@ def __datafusion_table_provider__(self) -> object:
371377
)
372378

373379

380+
def test_read_table_with_raw_table_provider_capsule(ctx):
381+
df_internal = pytest.importorskip("datafusion._internal")
382+
assert hasattr(
383+
df_internal.catalog.RawTable, "from_table_provider_capsule"
384+
)
385+
386+
capsule, _backing = _make_invalid_table_provider_capsule()
387+
388+
with pytest.raises(ValueError, match="missing a destructor"):
389+
ctx.read_table(capsule)
390+
391+
374392
def test_deregister_table(ctx, database):
375393
default = ctx.catalog()
376394
public = default.schema("public")

src/utils.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,11 @@ pub(crate) fn table_provider_from_capsule(
185185
pub(crate) fn table_provider_from_pycapsule(
186186
obj: &Bound<PyAny>,
187187
) -> PyResult<Option<Arc<dyn TableProvider>>> {
188-
if obj.hasattr("__datafusion_table_provider__")? {
188+
if let Ok(capsule) = obj.downcast::<PyCapsule>() {
189+
let provider = table_provider_from_capsule(&capsule)?;
190+
191+
Ok(Some(provider))
192+
} else if obj.hasattr("__datafusion_table_provider__")? {
189193
let capsule = obj.getattr("__datafusion_table_provider__")?.call0()?;
190194
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
191195
let provider = table_provider_from_capsule(&capsule)?;

0 commit comments

Comments
 (0)