Skip to content

Commit b9041ba

Browse files
committed
test: add unit test for refreshing reassigned pandas DataFrame in SQL context
1 parent 1764a57 commit b9041ba

File tree

2 files changed

+73
-15
lines changed

2 files changed

+73
-15
lines changed

python/datafusion/context.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import inspect
2323
import re
2424
import warnings
25+
import weakref
2526
from typing import TYPE_CHECKING, Any, Protocol
2627

2728
try:
@@ -563,6 +564,9 @@ def __init__(
563564
auto_python_table_lookup = getattr(config, "_python_table_lookup", False)
564565

565566
self._auto_python_table_lookup = bool(auto_python_table_lookup)
567+
self._python_table_bindings: dict[
568+
str, tuple[weakref.ReferenceType[Any] | None, int]
569+
] = {}
566570

567571
def __repr__(self) -> str:
568572
"""Print a string representation of the Session Context."""
@@ -592,6 +596,9 @@ def enable_url_table(self) -> SessionContext:
592596
obj._auto_python_table_lookup = getattr(
593597
self, "_auto_python_table_lookup", False
594598
)
599+
obj._python_table_bindings = getattr(
600+
self, "_python_table_bindings", {}
601+
).copy()
595602
return obj
596603

597604
def set_python_table_lookup(self, enabled: bool = True) -> SessionContext:
@@ -700,10 +707,13 @@ def _execute_sql() -> DataFrame:
700707

701708
auto_lookup_enabled = getattr(self, "_auto_python_table_lookup", False)
702709

710+
if auto_lookup_enabled:
711+
self._refresh_python_table_bindings()
712+
703713
while True:
704714
try:
705715
return _execute_sql()
706-
except Exception as err:
716+
except Exception as err: # noqa: PERF203
707717
if not auto_lookup_enabled:
708718
raise
709719

@@ -815,34 +825,60 @@ def _match(mapping: dict[str, Any]) -> Any | None:
815825
del frame
816826
return None
817827

828+
def _refresh_python_table_bindings(self) -> None:
    """Re-synchronise auto-registered tables with the Python objects they mirror.

    Each entry of ``_python_table_bindings`` maps a table name to a
    ``(weak reference or None, id)`` pair. For every entry the name is
    looked up again through ``_lookup_python_object``. The binding counts as
    stale when its weak reference no longer resolves, or when the lookup
    yields an object whose ``id`` differs from the recorded one. A stale
    table is deregistered; it is registered again when a replacement object
    is found and accepted, otherwise its binding is dropped.
    """
    tracked = getattr(self, "_python_table_bindings", {})

    # Iterate over a snapshot: the loop body removes and re-adds entries.
    for name, (ref, known_id) in list(tracked.items()):
        referent = None if ref is None else ref()
        candidate = self._lookup_python_object(name)

        referent_collected = ref is not None and referent is None
        rebound = candidate is not None and id(candidate) != known_id
        if not referent_collected and not rebound:
            continue

        self.deregister_table(name)

        # Keep the table only when a replacement exists and registers cleanly.
        if candidate is not None and self._register_python_object(name, candidate):
            continue

        tracked.pop(name, None)
849+
818850
def _register_python_object(self, name: str, obj: Any) -> bool:
    """Register ``obj`` as table ``name`` when it is a supported object.

    Supported objects, checked in this order:

    * a DataFusion ``DataFrame`` (registered as a view),
    * a polars ``DataFrame`` (matched by class module and name),
    * a pandas ``DataFrame`` (matched by class module and name),
    * a ``pyarrow`` ``Table``, ``RecordBatch`` or ``RecordBatchReader``, or
      any object exposing ``__arrow_c_stream__`` / ``__arrow_c_array__``.

    After a successful registration the object's identity is recorded in
    ``_python_table_bindings`` as ``(weak reference or None, id(obj))``.

    Args:
        name: Table name to register the object under.
        obj: Candidate Python object.

    Returns:
        ``True`` if the object was registered, ``False`` if its type is not
        supported.
    """
    registered = False

    if isinstance(obj, DataFrame):
        self.register_view(name, obj)
        registered = True
    elif (
        obj.__class__.__module__.startswith("polars.")
        and obj.__class__.__name__ == "DataFrame"
    ):
        self.from_polars(obj, name=name)
        registered = True
    elif (
        obj.__class__.__module__.startswith("pandas.")
        and obj.__class__.__name__ == "DataFrame"
    ):
        self.from_pandas(obj, name=name)
        registered = True
    elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)) or (
        hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__")
    ):
        self.from_arrow(obj, name=name)
        registered = True

    if registered:
        try:
            reference: weakref.ReferenceType[Any] | None = weakref.ref(obj)
        except TypeError:
            # Not every type supports weak references; fall back to id only.
            reference = None

        # Read the bookkeeping dict defensively, like the other accessors of
        # this attribute: an instance assembled without ``__init__`` may not
        # carry it yet, and the table has already been registered by now.
        bindings = getattr(self, "_python_table_bindings", None)
        if bindings is None:
            bindings = {}
            self._python_table_bindings = bindings
        bindings[name] = (reference, id(obj))

    return registered
846882

847883
def create_dataframe(
848884
self,
@@ -981,6 +1017,7 @@ def register_table(self, name: str, table: Table) -> None:
9811017
def deregister_table(self, name: str) -> None:
    """Remove a table from the session.

    Any binding stored for ``name`` in ``_python_table_bindings`` is
    discarded as well.

    Args:
        name: Name of the table to deregister.
    """
    self.ctx.deregister_table(name)
    # Read the bookkeeping dict defensively, like the other accessors of this
    # attribute: an instance assembled without ``__init__`` may not carry it,
    # and the table itself has already been removed at this point.
    bindings = getattr(self, "_python_table_bindings", None)
    if bindings is not None:
        bindings.pop(name, None)
9841021

9851022
def catalog_names(self) -> set[str]:
9861023
"""Returns the list of catalogs in this context."""

python/tests/test_context.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,27 @@ def test_sql_auto_register_pandas_dataframe():
375375
assert pytest.approx(result[0].column(0).to_pylist()[0]) == 2.5
376376

377377

378+
def test_sql_auto_register_refreshes_reassigned_dataframe():
    """A rebound local pandas DataFrame is picked up by the next SQL query."""
    pd = pytest.importorskip("pandas")

    query = "SELECT SUM(value) AS total FROM pandas_df"
    ctx = SessionContext(auto_register_python_objects=True)

    # The local must be called ``pandas_df``: the query refers to it by name.
    pandas_df = pd.DataFrame({"value": [1, 2, 3]})
    batches_before = ctx.sql(query).collect()
    assert batches_before[0].column(0).to_pylist()[0] == 6

    # Rebind the same name to a new frame; the total must follow the new data.
    pandas_df = pd.DataFrame({"value": [10, 20]})  # noqa: F841
    batches_after = ctx.sql(query).collect()
    assert batches_after[0].column(0).to_pylist()[0] == 30
397+
398+
378399
def test_sql_auto_register_polars_dataframe():
379400
pl = pytest.importorskip("polars")
380401

0 commit comments

Comments
 (0)