|
22 | 22 | import inspect |
23 | 23 | import re |
24 | 24 | import warnings |
| 25 | +import weakref |
25 | 26 | from typing import TYPE_CHECKING, Any, Protocol |
26 | 27 |
|
27 | 28 | try: |
@@ -563,6 +564,9 @@ def __init__( |
563 | 564 | auto_python_table_lookup = getattr(config, "_python_table_lookup", False) |
564 | 565 |
|
565 | 566 | self._auto_python_table_lookup = bool(auto_python_table_lookup) |
| 567 | + self._python_table_bindings: dict[ |
| 568 | + str, tuple[weakref.ReferenceType[Any] | None, int] |
| 569 | + ] = {} |
566 | 570 |
|
567 | 571 | def __repr__(self) -> str: |
568 | 572 | """Print a string representation of the Session Context.""" |
@@ -592,6 +596,9 @@ def enable_url_table(self) -> SessionContext: |
592 | 596 | obj._auto_python_table_lookup = getattr( |
593 | 597 | self, "_auto_python_table_lookup", False |
594 | 598 | ) |
| 599 | + obj._python_table_bindings = getattr( |
| 600 | + self, "_python_table_bindings", {} |
| 601 | + ).copy() |
595 | 602 | return obj |
596 | 603 |
|
597 | 604 | def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: |
@@ -700,10 +707,13 @@ def _execute_sql() -> DataFrame: |
700 | 707 |
|
701 | 708 | auto_lookup_enabled = getattr(self, "_auto_python_table_lookup", False) |
702 | 709 |
|
| 710 | + if auto_lookup_enabled: |
| 711 | + self._refresh_python_table_bindings() |
| 712 | + |
703 | 713 | while True: |
704 | 714 | try: |
705 | 715 | return _execute_sql() |
706 | | - except Exception as err: |
| 716 | + except Exception as err: # noqa: PERF203 |
707 | 717 | if not auto_lookup_enabled: |
708 | 718 | raise |
709 | 719 |
|
@@ -815,34 +825,60 @@ def _match(mapping: dict[str, Any]) -> Any | None: |
815 | 825 | del frame |
816 | 826 | return None |
817 | 827 |
|
| 828 | + def _refresh_python_table_bindings(self) -> None: |
| 829 | + bindings = getattr(self, "_python_table_bindings", {}) |
| 830 | + for table_name, (obj_ref, cached_id) in list(bindings.items()): |
| 831 | + cached_obj = obj_ref() if obj_ref is not None else None |
| 832 | + current_obj = self._lookup_python_object(table_name) |
| 833 | + weakref_dead = obj_ref is not None and cached_obj is None |
| 834 | + id_mismatch = current_obj is not None and id(current_obj) != cached_id |
| 835 | + |
| 836 | + if not (weakref_dead or id_mismatch): |
| 837 | + continue |
| 838 | + |
| 839 | + self.deregister_table(table_name) |
| 840 | + |
| 841 | + if current_obj is None: |
| 842 | + bindings.pop(table_name, None) |
| 843 | + continue |
| 844 | + |
| 845 | + if self._register_python_object(table_name, current_obj): |
| 846 | + continue |
| 847 | + |
| 848 | + bindings.pop(table_name, None) |
| 849 | + |
818 | 850 | def _register_python_object(self, name: str, obj: Any) -> bool: |
| 851 | + registered = False |
| 852 | + |
819 | 853 | if isinstance(obj, DataFrame): |
820 | 854 | self.register_view(name, obj) |
821 | | - return True |
822 | | - |
823 | | - if ( |
| 855 | + registered = True |
| 856 | + elif ( |
824 | 857 | obj.__class__.__module__.startswith("polars.") |
825 | 858 | and obj.__class__.__name__ == "DataFrame" |
826 | 859 | ): |
827 | 860 | self.from_polars(obj, name=name) |
828 | | - return True |
829 | | - |
830 | | - if ( |
| 861 | + registered = True |
| 862 | + elif ( |
831 | 863 | obj.__class__.__module__.startswith("pandas.") |
832 | 864 | and obj.__class__.__name__ == "DataFrame" |
833 | 865 | ): |
834 | 866 | self.from_pandas(obj, name=name) |
835 | | - return True |
836 | | - |
837 | | - if isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)): |
| 867 | + registered = True |
| 868 | + elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)) or ( |
| 869 | + hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__") |
| 870 | + ): |
838 | 871 | self.from_arrow(obj, name=name) |
839 | | - return True |
| 872 | + registered = True |
840 | 873 |
|
841 | | - if hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__"): |
842 | | - self.from_arrow(obj, name=name) |
843 | | - return True |
| 874 | + if registered: |
| 875 | + try: |
| 876 | + reference: weakref.ReferenceType[Any] | None = weakref.ref(obj) |
| 877 | + except TypeError: |
| 878 | + reference = None |
| 879 | + self._python_table_bindings[name] = (reference, id(obj)) |
844 | 880 |
|
845 | | - return False |
| 881 | + return registered |
846 | 882 |
|
847 | 883 | def create_dataframe( |
848 | 884 | self, |
@@ -981,6 +1017,7 @@ def register_table(self, name: str, table: Table) -> None: |
981 | 1017 | def deregister_table(self, name: str) -> None: |
982 | 1018 | """Remove a table from the session.""" |
983 | 1019 | self.ctx.deregister_table(name) |
| 1020 | + self._python_table_bindings.pop(name, None) |
984 | 1021 |
|
985 | 1022 | def catalog_names(self) -> set[str]: |
986 | 1023 | """Returns the list of catalogs in this context.""" |
|
0 commit comments