Skip to content

Commit 10f5f47

Browse files
committed
feat: enable implicit table lookup for Python objects in SQL queries
1 parent 2362ae2 commit 10f5f47

File tree

1 file changed

+168
-133
lines changed

1 file changed

+168
-133
lines changed

python/datafusion/context.py

Lines changed: 168 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919

2020
from __future__ import annotations
2121

22+
import importlib
2223
import inspect
2324
import re
2425
import warnings
26+
from functools import cache
2527
from typing import TYPE_CHECKING, Any, Iterator, Protocol
2628

2729
try:
@@ -55,6 +57,15 @@
5557
from datafusion.plan import ExecutionPlan, LogicalPlan
5658

5759

60+
@cache
61+
def _load_optional_module(module_name: str) -> Any | None:
62+
"""Return the module for *module_name* if it can be imported."""
63+
try:
64+
return importlib.import_module(module_name)
65+
except ModuleNotFoundError:
66+
return None
67+
68+
5869
class ArrowStreamExportable(Protocol):
5970
"""Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.
6071
@@ -105,6 +116,7 @@ def __init__(self, config_options: dict[str, str] | None = None) -> None:
105116
config_options: Configuration options.
106117
"""
107118
self.config_internal = SessionConfigInternal(config_options)
119+
self._python_table_lookup = False
108120

109121
def with_create_default_catalog_and_schema(
110122
self, enabled: bool = True
@@ -274,6 +286,11 @@ def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig:
274286
self.config_internal = self.config_internal.with_parquet_pruning(enabled)
275287
return self
276288

289+
def with_python_table_lookup(self, enabled: bool = True) -> SessionConfig:
290+
"""Enable implicit table lookup for Python objects when running SQL."""
291+
self._python_table_lookup = enabled
292+
return self
293+
277294
def set(self, key: str, value: str) -> SessionConfig:
278295
"""Set a configuration option.
279296
@@ -513,11 +530,17 @@ def __init__(
513530
ctx = SessionContext()
514531
df = ctx.read_csv("data.csv")
515532
"""
516-
config = config.config_internal if config is not None else None
517-
runtime = runtime.config_internal if runtime is not None else None
533+
python_table_lookup = auto_register_python_variables # Use parameter as default
534+
if config is not None:
535+
python_table_lookup = config._python_table_lookup
536+
config_internal = config.config_internal
537+
else:
538+
config_internal = None
539+
540+
runtime_internal = runtime.config_internal if runtime is not None else None
518541

519-
self.ctx = SessionContextInternal(config, runtime)
520-
self._auto_register_python_variables = auto_register_python_variables
542+
self.ctx = SessionContextInternal(config_internal, runtime_internal)
543+
self._python_table_lookup = python_table_lookup
521544

522545
def __repr__(self) -> str:
523546
"""Print a string representation of the Session Context."""
@@ -544,17 +567,27 @@ def enable_url_table(self) -> SessionContext:
544567
klass = self.__class__
545568
obj = klass.__new__(klass)
546569
obj.ctx = self.ctx.enable_url_table()
547-
obj._auto_register_python_variables = self._auto_register_python_variables
570+
obj._python_table_lookup = self._python_table_lookup
548571
return obj
549572

573+
def set_python_table_lookup(self, enabled: bool) -> None:
574+
"""Enable or disable implicit table lookup for Python objects."""
575+
self._python_table_lookup = enabled
576+
577+
# Backward compatibility properties
550578
@property
551579
def auto_register_python_variables(self) -> bool:
552580
"""Toggle automatic registration of Python variables in SQL queries."""
553-
return self._auto_register_python_variables
581+
return self._python_table_lookup
554582

555583
@auto_register_python_variables.setter
556584
def auto_register_python_variables(self, enabled: bool) -> None:
557-
self._auto_register_python_variables = bool(enabled)
585+
self._python_table_lookup = bool(enabled)
586+
587+
def _extract_missing_table_names(self, error: Exception) -> set[str]:
588+
"""Extract missing table names from error (backward compatibility)."""
589+
missing_table = self._extract_missing_table_name(error)
590+
return {missing_table} if missing_table else set()
558591

559592
def register_object_store(
560593
self, schema: str, store: Any, host: str | None = None
@@ -620,12 +653,29 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
620653
Returns:
621654
DataFrame representation of the SQL query.
622655
"""
623-
options_internal = None if options is None else options.options_internal
624-
return self._sql_with_retry(
625-
query,
626-
options_internal,
627-
self._auto_register_python_variables,
628-
)
656+
attempted_missing_tables: set[str] = set()
657+
658+
while True:
659+
try:
660+
if options is None:
661+
result = self.ctx.sql(query)
662+
else:
663+
result = self.ctx.sql_with_options(query, options.options_internal)
664+
except Exception as exc:
665+
missing_table = self._extract_missing_table_name(exc)
666+
if (
667+
missing_table is None
668+
or missing_table in attempted_missing_tables
669+
or not self._python_table_lookup
670+
):
671+
raise
672+
673+
attempted_missing_tables.add(missing_table)
674+
if not self._register_missing_table_from_callers(missing_table):
675+
raise
676+
continue
677+
678+
return DataFrame(result)
629679

630680
def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
631681
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
@@ -642,137 +692,122 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
642692
"""
643693
return self.sql(query, options)
644694

645-
def _sql_with_retry(
646-
self,
647-
query: str,
648-
options_internal: SQLOptionsInternal | None,
649-
allow_retry: bool,
650-
) -> DataFrame:
651-
try:
652-
if options_internal is None:
653-
return DataFrame(self.ctx.sql(query))
654-
return DataFrame(self.ctx.sql_with_options(query, options_internal))
655-
except Exception as exc:
656-
if not allow_retry or not self._handle_missing_table_error(exc):
657-
raise
658-
return self._sql_with_retry(query, options_internal, allow_retry)
659-
660-
def _handle_missing_table_error(self, error: Exception) -> bool:
661-
missing_tables = self._extract_missing_table_names(error)
662-
if not missing_tables:
663-
return False
664-
665-
registered_any = False
666-
attempted: set[str] = set()
667-
for raw_name in missing_tables:
668-
for candidate in self._candidate_table_names(raw_name):
669-
if candidate in attempted:
670-
continue
671-
attempted.add(candidate)
672-
673-
value = self._lookup_python_variable(candidate)
674-
if value is None:
675-
continue
676-
if self._register_python_value(candidate, value):
677-
registered_any = True
678-
break
679-
return registered_any
680-
681-
def _candidate_table_names(self, identifier: str) -> Iterator[str]:
682-
cleaned = identifier.strip().strip('"')
683-
if not cleaned:
684-
return
685-
686-
seen: set[str] = set()
687-
candidates = [cleaned]
688-
if "." in cleaned:
689-
candidates.append(cleaned.rsplit(".", 1)[-1])
690-
691-
for candidate in candidates:
692-
normalized = candidate.strip()
693-
if not normalized or normalized in seen:
694-
continue
695-
seen.add(normalized)
696-
yield normalized
697-
698-
def _extract_missing_table_names(self, error: Exception) -> set[str]:
699-
names: set[str] = set()
700-
attribute = getattr(error, "missing_table_names", None)
701-
if attribute is not None:
702-
if isinstance(attribute, (list, tuple, set, frozenset)):
703-
for item in attribute:
704-
if item is None:
705-
continue
706-
for candidate in self._candidate_table_names(str(item)):
707-
names.add(candidate)
708-
elif attribute is not None:
709-
for candidate in self._candidate_table_names(str(attribute)):
710-
names.add(candidate)
711-
if names:
712-
return names
713-
695+
@staticmethod
696+
def _extract_missing_table_name(error: Exception) -> str | None:
714697
message = str(error)
715-
return {match.group(1) for match in _MISSING_TABLE_PATTERN.finditer(message)}
698+
patterns = (
699+
r"table '([^']+)' not found",
700+
r"Table not found: ['\"]?([^\s'\"]+)['\"]?",
701+
r"Table or CTE with name ['\"]?([^\s'\"]+)['\"]? not found",
702+
r"Invalid reference to table ['\"]?([^\s'\"]+)['\"]?",
703+
)
704+
for pattern in patterns:
705+
if match := re.search(pattern, message):
706+
return match.group(1)
707+
return None
716708

717-
def _lookup_python_variable(self, name: str) -> Any | None:
709+
def _register_missing_table_from_callers(self, table_name: str) -> bool:
718710
frame = inspect.currentframe()
719-
outer = frame.f_back if frame is not None else None
720-
lower_name = name.lower()
711+
if frame is None:
712+
return False
721713

722714
try:
723-
while outer is not None:
724-
for mapping in (outer.f_locals, outer.f_globals):
725-
if not mapping:
726-
continue
727-
if name in mapping:
728-
value = mapping[name]
729-
if value is not None:
730-
return value
731-
# allow outer scopes to provide a non-``None`` value
732-
continue
733-
for key, value in mapping.items():
734-
if value is None:
735-
continue
736-
if key == name or key.lower() == lower_name:
737-
return value
738-
outer = outer.f_back
715+
frame = frame.f_back
716+
if frame is None:
717+
return False
718+
frame = frame.f_back
719+
while frame is not None:
720+
if self._register_from_namespace(table_name, frame.f_locals):
721+
return True
722+
if self._register_from_namespace(table_name, frame.f_globals):
723+
return True
724+
frame = frame.f_back
739725
finally:
740-
del outer
741726
del frame
742-
return None
727+
return False
743728

744-
def _register_python_value(self, table_name: str, value: Any) -> bool:
745-
if value is None:
729+
def _register_from_namespace(
730+
self, table_name: str, namespace: dict[str, Any]
731+
) -> bool:
732+
if table_name not in namespace:
746733
return False
734+
value = namespace[table_name]
735+
return self._register_python_value(table_name, value)
736+
737+
def _register_python_value(self, table_name: str, value: Any) -> bool:
738+
pandas = _load_optional_module("pandas")
739+
polars = _load_optional_module("polars")
740+
polars_df = getattr(polars, "DataFrame", None) if polars is not None else None
741+
742+
handlers = (
743+
(isinstance(value, DataFrame), self._register_datafusion_dataframe),
744+
(
745+
isinstance(value, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)),
746+
self._register_arrow_object,
747+
),
748+
(
749+
pandas is not None and isinstance(value, pandas.DataFrame),
750+
self._register_pandas_dataframe,
751+
),
752+
(
753+
polars_df is not None and isinstance(value, polars_df),
754+
self._register_polars_dataframe,
755+
),
756+
)
757+
758+
for matches, handler in handlers:
759+
if matches:
760+
return handler(table_name, value)
761+
762+
return False
747763

748-
registered = False
749-
if isinstance(value, DataFrame):
764+
def _register_datafusion_dataframe(self, table_name: str, value: DataFrame) -> bool:
765+
try:
750766
self.register_view(table_name, value)
751-
registered = True
752-
elif isinstance(value, Table):
753-
self.register_table(table_name, value)
754-
registered = True
755-
else:
756-
provider = getattr(value, "__datafusion_table_provider__", None)
757-
if callable(provider):
758-
self.register_table_provider(table_name, value)
759-
registered = True
760-
elif hasattr(value, "__arrow_c_stream__") or hasattr(
761-
value, "__arrow_c_array__"
762-
):
763-
self.from_arrow(value, name=table_name)
764-
registered = True
765-
else:
766-
module_name = getattr(type(value), "__module__", "") or ""
767-
class_name = getattr(type(value), "__name__", "") or ""
768-
if module_name.startswith("pandas.") and class_name == "DataFrame":
769-
self.from_pandas(value, name=table_name)
770-
registered = True
771-
elif module_name.startswith("polars") and class_name == "DataFrame":
772-
self.from_polars(value, name=table_name)
773-
registered = True
774-
775-
return registered
767+
except Exception as exc: # noqa: BLE001
768+
warnings.warn(
769+
"Failed to register DataFusion DataFrame for table "
770+
f"'{table_name}': {exc}",
771+
stacklevel=4,
772+
)
773+
return False
774+
return True
775+
776+
def _register_arrow_object(self, table_name: str, value: Any) -> bool:
777+
try:
778+
self.from_arrow(value, table_name)
779+
except Exception as exc: # noqa: BLE001
780+
warnings.warn(
781+
"Failed to register Arrow data for table "
782+
f"'{table_name}': {exc}",
783+
stacklevel=4,
784+
)
785+
return False
786+
return True
787+
788+
def _register_pandas_dataframe(self, table_name: str, value: Any) -> bool:
789+
try:
790+
self.from_pandas(value, table_name)
791+
except Exception as exc: # noqa: BLE001
792+
warnings.warn(
793+
"Failed to register pandas DataFrame for table "
794+
f"'{table_name}': {exc}",
795+
stacklevel=4,
796+
)
797+
return False
798+
return True
799+
800+
def _register_polars_dataframe(self, table_name: str, value: Any) -> bool:
801+
try:
802+
self.from_polars(value, table_name)
803+
except Exception as exc: # noqa: BLE001
804+
warnings.warn(
805+
"Failed to register polars DataFrame for table "
806+
f"'{table_name}': {exc}",
807+
stacklevel=4,
808+
)
809+
return False
810+
return True
776811

777812
def create_dataframe(
778813
self,

0 commit comments

Comments
 (0)