From 66d74a38d62de790908650dbaba1ab7993cf1ee5 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 15:42:11 +0800 Subject: [PATCH 01/20] feat: add automatic variable registration for Arrow-compatible Python objects in SQL queries --- docs/source/user-guide/sql.rst | 27 +++++- python/datafusion/context.py | 163 ++++++++++++++++++++++++++++++++- python/tests/test_context.py | 65 ++++++++++++- src/context.rs | 85 ++++++++++++++++- 4 files changed, 329 insertions(+), 11 deletions(-) diff --git a/docs/source/user-guide/sql.rst b/docs/source/user-guide/sql.rst index 6fa7f0c6a..9556bf3e0 100644 --- a/docs/source/user-guide/sql.rst +++ b/docs/source/user-guide/sql.rst @@ -36,4 +36,29 @@ DataFusion also offers a SQL API, read the full reference `here 2") + print(result.to_pandas()) + +The feature inspects the call stack for variables whose names match missing +tables and registers them if they expose Arrow data (including pandas and +Polars DataFrames). Existing contexts can enable or disable the behavior at +runtime through the :py:attr:`SessionContext.auto_register_python_variables` +property. diff --git a/python/datafusion/context.py b/python/datafusion/context.py index b6e728b51..59cf575ce 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,8 +19,10 @@ from __future__ import annotations +import inspect +import re import warnings -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Iterator, Protocol try: from warnings import deprecated # Python 3.13+ @@ -41,6 +43,8 @@ from ._internal import SQLOptions as SQLOptionsInternal from ._internal import expr as expr_internal +_MISSING_TABLE_PATTERN = re.compile(r"(?i)(?:table|view) '([^']+)' not found") + if TYPE_CHECKING: import pathlib from collections.abc import Sequence @@ -483,6 +487,8 @@ def __init__( self, config: SessionConfig | None = None, runtime: RuntimeEnvBuilder | None = None, + *, + auto_register_python_variables: bool = False, ) -> None: """Main interface for executing queries with DataFusion. @@ -493,6 +499,9 @@ def __init__( Args: config: Session configuration options. runtime: Runtime configuration options. + auto_register_python_variables: Automatically register Arrow-like + Python objects referenced in SQL queries when they are available + in the caller's scope. Example usage: @@ -508,6 +517,7 @@ def __init__( runtime = runtime.config_internal if runtime is not None else None self.ctx = SessionContextInternal(config, runtime) + self._auto_register_python_variables = auto_register_python_variables def __repr__(self) -> str: """Print a string representation of the Session Context.""" @@ -534,8 +544,18 @@ def enable_url_table(self) -> SessionContext: klass = self.__class__ obj = klass.__new__(klass) obj.ctx = self.ctx.enable_url_table() + obj._auto_register_python_variables = self._auto_register_python_variables return obj + @property + def auto_register_python_variables(self) -> bool: + """Toggle automatic registration of Python variables in SQL queries.""" + return self._auto_register_python_variables + + @auto_register_python_variables.setter + def auto_register_python_variables(self, enabled: bool) -> None: + self._auto_register_python_variables = bool(enabled) + def register_object_store( self, schema: str, store: Any, host: str | None = None ) -> None: @@ -600,9 +620,12 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: Returns: DataFrame representation of the SQL query. """ - if options is None: - return DataFrame(self.ctx.sql(query)) - return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) + options_internal = None if options is None else options.options_internal + return self._sql_with_retry( + query, + options_internal, + self._auto_register_python_variables, + ) def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text. @@ -619,6 +642,138 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """ return self.sql(query, options) + def _sql_with_retry( + self, + query: str, + options_internal: SQLOptionsInternal | None, + allow_retry: bool, + ) -> DataFrame: + try: + if options_internal is None: + return DataFrame(self.ctx.sql(query)) + return DataFrame(self.ctx.sql_with_options(query, options_internal)) + except Exception as exc: + if not allow_retry or not self._handle_missing_table_error(exc): + raise + return self._sql_with_retry(query, options_internal, allow_retry) + + def _handle_missing_table_error(self, error: Exception) -> bool: + missing_tables = self._extract_missing_table_names(error) + if not missing_tables: + return False + + registered_any = False + attempted: set[str] = set() + for raw_name in missing_tables: + for candidate in self._candidate_table_names(raw_name): + if candidate in attempted: + continue + attempted.add(candidate) + + value = self._lookup_python_variable(candidate) + if value is None: + continue + if self._register_python_value(candidate, value): + registered_any = True + break + return registered_any + + def _candidate_table_names(self, identifier: str) -> Iterator[str]: + cleaned = identifier.strip().strip('"') + if not cleaned: + return + + seen: set[str] = set() + candidates = [cleaned] + if "." in cleaned: + candidates.append(cleaned.rsplit(".", 1)[-1]) + + for candidate in candidates: + normalized = candidate.strip() + if not normalized or normalized in seen: + continue + seen.add(normalized) + yield normalized + + def _extract_missing_table_names(self, error: Exception) -> set[str]: + names: set[str] = set() + attribute = getattr(error, "missing_table_names", None) + if attribute is not None: + if isinstance(attribute, (list, tuple, set, frozenset)): + for item in attribute: + if item is None: + continue + for candidate in self._candidate_table_names(str(item)): + names.add(candidate) + elif attribute is not None: + for candidate in self._candidate_table_names(str(attribute)): + names.add(candidate) + if names: + return names + + message = str(error) + return {match.group(1) for match in _MISSING_TABLE_PATTERN.finditer(message)} + + def _lookup_python_variable(self, name: str) -> Any | None: + frame = inspect.currentframe() + outer = frame.f_back if frame is not None else None + lower_name = name.lower() + + try: + while outer is not None: + for mapping in (outer.f_locals, outer.f_globals): + if not mapping: + continue + if name in mapping: + value = mapping[name] + if value is not None: + return value + # allow outer scopes to provide a non-``None`` value + continue + for key, value in mapping.items(): + if value is None: + continue + if key == name or key.lower() == lower_name: + return value + outer = outer.f_back + finally: + del outer + del frame + return None + + def _register_python_value(self, table_name: str, value: Any) -> bool: + if value is None: + return False + + registered = False + if isinstance(value, DataFrame): + self.register_view(table_name, value) + registered = True + elif isinstance(value, Table): + self.register_table(table_name, value) + registered = True + else: + provider = getattr(value, "__datafusion_table_provider__", None) + if callable(provider): + self.register_table_provider(table_name, value) + registered = True + elif hasattr(value, "__arrow_c_stream__") or hasattr( + value, "__arrow_c_array__" + ): + self.from_arrow(value, name=table_name) + registered = True + else: + module_name = getattr(type(value), "__module__", "") or "" + class_name = getattr(type(value), "__name__", "") or "" + if module_name.startswith("pandas.") and class_name == "DataFrame": + self.from_pandas(value, name=table_name) + registered = True + elif module_name.startswith("polars") and class_name == "DataFrame": + self.from_polars(value, name=table_name) + registered = True + + return registered + def create_dataframe( self, partitions: list[list[pa.RecordBatch]], diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 6dbcc0d5e..5a010dcf6 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -255,6 +255,69 @@ def test_from_pylist(ctx): assert df.collect()[0].num_rows == 3 +def test_sql_missing_table_without_auto_register(ctx): + arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 + + with pytest.raises(Exception, match="not found") as excinfo: + ctx.sql("SELECT * FROM arrow_table").collect() + + missing = getattr(excinfo.value, "missing_table_names", None) + assert missing is not None + assert "arrow_table" in set(ctx._extract_missing_table_names(excinfo.value)) + + +def test_sql_auto_register_arrow_table(): + ctx = SessionContext(auto_register_python_variables=True) + arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 + + result = ctx.sql( + "SELECT SUM(value) AS total FROM arrow_table", + ).collect() + + assert ctx.table_exist("arrow_table") + assert result[0].column(0).to_pylist()[0] == 6 + + +def test_sql_auto_register_arrow_outer_scope(): + ctx = SessionContext() + ctx.auto_register_python_variables = True + arrow_table = pa.Table.from_pydict({"value": [1, 2, 3, 4]}) # noqa: F841 + + def run_query(): + return ctx.sql( + "SELECT COUNT(*) AS total_rows FROM arrow_table", + ).collect() + + result = run_query() + assert result[0].column(0).to_pylist()[0] == 4 + + +def test_sql_auto_register_pandas_dataframe(): + pd = pytest.importorskip("pandas") + + ctx = SessionContext(auto_register_python_variables=True) + pandas_df = pd.DataFrame({"value": [1, 2, 3, 4]}) # noqa: F841 + + result = ctx.sql( + "SELECT AVG(value) AS avg_value FROM pandas_df", + ).collect() + + assert pytest.approx(result[0].column(0).to_pylist()[0]) == 2.5 + + +def test_sql_auto_register_polars_dataframe(): + pl = pytest.importorskip("polars") + + ctx = SessionContext(auto_register_python_variables=True) + polars_df = pl.DataFrame({"value": [2, 4, 6]}) # noqa: F841 + + result = ctx.sql( + "SELECT MIN(value) AS min_value FROM polars_df", + ).collect() + + assert result[0].column(0).to_pylist()[0] == 2 + + def test_from_pydict(ctx): # create a dataframe from Python dictionary data = {"a": [1, 2, 3], "b": [4, 5, 6]} @@ -484,7 +547,7 @@ def test_table_exist(ctx): def test_table_not_found(ctx): - from uuid import uuid4 + from uuid import uuid4 # noqa: PLC0415 with pytest.raises(KeyError): ctx.table(f"not-found-{uuid4()}") diff --git a/src/context.rs b/src/context.rs index 36133a33d..9dd9db37a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -34,7 +34,7 @@ use pyo3::prelude::*; use crate::catalog::{PyCatalog, PyTable, RustWrappedPyCatalogProvider}; use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; -use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult}; +use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::expr::sort_expr::PySortExpr; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; @@ -59,6 +59,7 @@ use datafusion::datasource::listing::{ }; use datafusion::datasource::MemTable; use datafusion::datasource::TableProvider; +use datafusion::error::DataFusionError; use datafusion::execution::context::{ DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext, }; @@ -435,8 +436,11 @@ impl PySessionContext { /// Returns a PyDataFrame whose plan corresponds to the SQL statement. pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult { let result = self.ctx.sql(query); - let df = wait_for_future(py, result)??; - Ok(PyDataFrame::new(df)) + match wait_for_future(py, result) { + Ok(Ok(df)) => Ok(PyDataFrame::new(df)), + Ok(Err(err)) => Err(py_datafusion_error_with_missing_tables(py, err)), + Err(py_err) => Err(PyDataFusionError::PythonError(py_err)), + } } #[pyo3(signature = (query, options=None))] @@ -452,8 +456,11 @@ impl PySessionContext { SQLOptions::new() }; let result = self.ctx.sql_with_options(query, options); - let df = wait_for_future(py, result)??; - Ok(PyDataFrame::new(df)) + match wait_for_future(py, result) { + Ok(Ok(df)) => Ok(PyDataFrame::new(df)), + Ok(Err(err)) => Err(py_datafusion_error_with_missing_tables(py, err)), + Err(py_err) => Err(PyDataFusionError::PythonError(py_err)), + } } #[pyo3(signature = (partitions, name=None, schema=None))] @@ -1188,6 +1195,74 @@ impl PySessionContext { } } +fn py_datafusion_error_with_missing_tables(py: Python, err: DataFusionError) -> PyDataFusionError { + let missing_tables = collect_missing_table_names(&err); + let py_err: PyErr = PyDataFusionError::from(err).into(); + + if !missing_tables.is_empty() { + if let Ok(py_names) = PyList::new(py, &missing_tables) { + let _ = py_err + .value(py) + .setattr("missing_table_names", py_names.into_any()); + } + } + + PyDataFusionError::PythonError(py_err) +} + +fn collect_missing_table_names(err: &DataFusionError) -> Vec { + let mut names = HashSet::new(); + collect_missing_table_names_recursive(err, &mut names); + + let mut collected: Vec = names.into_iter().collect(); + collected.sort(); + collected +} + +fn collect_missing_table_names_recursive(err: &DataFusionError, acc: &mut HashSet) { + match err { + DataFusionError::Plan(message) + | DataFusionError::Execution(message) + | DataFusionError::Configuration(message) + | DataFusionError::NotImplemented(message) + | DataFusionError::ResourcesExhausted(message) + | DataFusionError::Internal(message) => { + parse_missing_table_names_in_message(message, acc); + } + DataFusionError::Context(_, inner) | DataFusionError::Diagnostic(_, inner) => { + collect_missing_table_names_recursive(inner, acc); + } + _ => {} + } +} + +fn parse_missing_table_names_in_message(message: &str, acc: &mut HashSet) { + const LOOKUPS: [(&str, char); 4] = [ + ("table '", '\''), + ("view '", '\''), + ("table \"", '"'), + ("view \"", '"'), + ]; + + let lower = message.to_ascii_lowercase(); + for (needle, terminator) in LOOKUPS { + let mut search_start = 0usize; + while let Some(relative) = lower[search_start..].find(needle) { + let start = search_start + relative + needle.len(); + let remainder = &message[start..]; + if let Some(end) = remainder.find(terminator) { + let name = &remainder[..end]; + if !name.is_empty() { + acc.insert(name.to_string()); + } + search_start = start + end + 1; + } else { + break; + } + } + } +} + pub fn parse_file_compression_type( file_compression_type: Option, ) -> Result { From 65e44928cca6e87781b848ea454eb9fbc71807af Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 15:48:36 +0800 Subject: [PATCH 02/20] fix: remove noqa directive for uuid4 import in test_table_not_found --- python/tests/test_context.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 5a010dcf6..48fa9ed9e 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -547,7 +547,7 @@ def test_table_exist(ctx): def test_table_not_found(ctx): - from uuid import uuid4 # noqa: PLC0415 + from uuid import uuid4 with pytest.raises(KeyError): ctx.table(f"not-found-{uuid4()}") From 53a62f780cee94cb4a334fcc9b3e7405140c65c7 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 16:52:46 +0800 Subject: [PATCH 03/20] feat: enable implicit table lookup for Python objects in SQL queries --- python/datafusion/context.py | 301 +++++++++++++++++++---------------- 1 file changed, 168 insertions(+), 133 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 59cf575ce..e41199cce 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,9 +19,11 @@ from __future__ import annotations +import importlib import inspect import re import warnings +from functools import cache from typing import TYPE_CHECKING, Any, Iterator, Protocol try: @@ -55,6 +57,15 @@ from datafusion.plan import ExecutionPlan, LogicalPlan +@cache +def _load_optional_module(module_name: str) -> Any | None: + """Return the module for *module_name* if it can be imported.""" + try: + return importlib.import_module(module_name) + except ModuleNotFoundError: + return None + + class ArrowStreamExportable(Protocol): """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface. @@ -105,6 +116,7 @@ def __init__(self, config_options: dict[str, str] | None = None) -> None: config_options: Configuration options. """ self.config_internal = SessionConfigInternal(config_options) + self._python_table_lookup = False def with_create_default_catalog_and_schema( self, enabled: bool = True @@ -274,6 +286,11 @@ def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig: self.config_internal = self.config_internal.with_parquet_pruning(enabled) return self + def with_python_table_lookup(self, enabled: bool = True) -> SessionConfig: + """Enable implicit table lookup for Python objects when running SQL.""" + self._python_table_lookup = enabled + return self + def set(self, key: str, value: str) -> SessionConfig: """Set a configuration option. @@ -513,11 +530,17 @@ def __init__( ctx = SessionContext() df = ctx.read_csv("data.csv") """ - config = config.config_internal if config is not None else None - runtime = runtime.config_internal if runtime is not None else None + python_table_lookup = auto_register_python_variables # Use parameter as default + if config is not None: + python_table_lookup = config._python_table_lookup + config_internal = config.config_internal + else: + config_internal = None + + runtime_internal = runtime.config_internal if runtime is not None else None - self.ctx = SessionContextInternal(config, runtime) - self._auto_register_python_variables = auto_register_python_variables + self.ctx = SessionContextInternal(config_internal, runtime_internal) + self._python_table_lookup = python_table_lookup def __repr__(self) -> str: """Print a string representation of the Session Context.""" @@ -544,17 +567,27 @@ def enable_url_table(self) -> SessionContext: klass = self.__class__ obj = klass.__new__(klass) obj.ctx = self.ctx.enable_url_table() - obj._auto_register_python_variables = self._auto_register_python_variables + obj._python_table_lookup = self._python_table_lookup return obj + def set_python_table_lookup(self, enabled: bool) -> None: + """Enable or disable implicit table lookup for Python objects.""" + self._python_table_lookup = enabled + + # Backward compatibility properties @property def auto_register_python_variables(self) -> bool: """Toggle automatic registration of Python variables in SQL queries.""" - return self._auto_register_python_variables + return self._python_table_lookup @auto_register_python_variables.setter def auto_register_python_variables(self, enabled: bool) -> None: - self._auto_register_python_variables = bool(enabled) + self._python_table_lookup = bool(enabled) + + def _extract_missing_table_names(self, error: Exception) -> set[str]: + """Extract missing table names from error (backward compatibility).""" + missing_table = self._extract_missing_table_name(error) + return {missing_table} if missing_table else set() def register_object_store( self, schema: str, store: Any, host: str | None = None @@ -620,12 +653,29 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: Returns: DataFrame representation of the SQL query. """ - options_internal = None if options is None else options.options_internal - return self._sql_with_retry( - query, - options_internal, - self._auto_register_python_variables, - ) + attempted_missing_tables: set[str] = set() + + while True: + try: + if options is None: + result = self.ctx.sql(query) + else: + result = self.ctx.sql_with_options(query, options.options_internal) + except Exception as exc: + missing_table = self._extract_missing_table_name(exc) + if ( + missing_table is None + or missing_table in attempted_missing_tables + or not self._python_table_lookup + ): + raise + + attempted_missing_tables.add(missing_table) + if not self._register_missing_table_from_callers(missing_table): + raise + continue + + return DataFrame(result) def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text. @@ -642,137 +692,122 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """ return self.sql(query, options) - def _sql_with_retry( - self, - query: str, - options_internal: SQLOptionsInternal | None, - allow_retry: bool, - ) -> DataFrame: - try: - if options_internal is None: - return DataFrame(self.ctx.sql(query)) - return DataFrame(self.ctx.sql_with_options(query, options_internal)) - except Exception as exc: - if not allow_retry or not self._handle_missing_table_error(exc): - raise - return self._sql_with_retry(query, options_internal, allow_retry) - - def _handle_missing_table_error(self, error: Exception) -> bool: - missing_tables = self._extract_missing_table_names(error) - if not missing_tables: - return False - - registered_any = False - attempted: set[str] = set() - for raw_name in missing_tables: - for candidate in self._candidate_table_names(raw_name): - if candidate in attempted: - continue - attempted.add(candidate) - - value = self._lookup_python_variable(candidate) - if value is None: - continue - if self._register_python_value(candidate, value): - registered_any = True - break - return registered_any - - def _candidate_table_names(self, identifier: str) -> Iterator[str]: - cleaned = identifier.strip().strip('"') - if not cleaned: - return - - seen: set[str] = set() - candidates = [cleaned] - if "." in cleaned: - candidates.append(cleaned.rsplit(".", 1)[-1]) - - for candidate in candidates: - normalized = candidate.strip() - if not normalized or normalized in seen: - continue - seen.add(normalized) - yield normalized - - def _extract_missing_table_names(self, error: Exception) -> set[str]: - names: set[str] = set() - attribute = getattr(error, "missing_table_names", None) - if attribute is not None: - if isinstance(attribute, (list, tuple, set, frozenset)): - for item in attribute: - if item is None: - continue - for candidate in self._candidate_table_names(str(item)): - names.add(candidate) - elif attribute is not None: - for candidate in self._candidate_table_names(str(attribute)): - names.add(candidate) - if names: - return names - + @staticmethod + def _extract_missing_table_name(error: Exception) -> str | None: message = str(error) - return {match.group(1) for match in _MISSING_TABLE_PATTERN.finditer(message)} + patterns = ( + r"table '([^']+)' not found", + r"Table not found: ['\"]?([^\s'\"]+)['\"]?", + r"Table or CTE with name ['\"]?([^\s'\"]+)['\"]? not found", + r"Invalid reference to table ['\"]?([^\s'\"]+)['\"]?", + ) + for pattern in patterns: + if match := re.search(pattern, message): + return match.group(1) + return None - def _lookup_python_variable(self, name: str) -> Any | None: + def _register_missing_table_from_callers(self, table_name: str) -> bool: frame = inspect.currentframe() - outer = frame.f_back if frame is not None else None - lower_name = name.lower() + if frame is None: + return False try: - while outer is not None: - for mapping in (outer.f_locals, outer.f_globals): - if not mapping: - continue - if name in mapping: - value = mapping[name] - if value is not None: - return value - # allow outer scopes to provide a non-``None`` value - continue - for key, value in mapping.items(): - if value is None: - continue - if key == name or key.lower() == lower_name: - return value - outer = outer.f_back + frame = frame.f_back + if frame is None: + return False + frame = frame.f_back + while frame is not None: + if self._register_from_namespace(table_name, frame.f_locals): + return True + if self._register_from_namespace(table_name, frame.f_globals): + return True + frame = frame.f_back finally: - del outer del frame - return None + return False - def _register_python_value(self, table_name: str, value: Any) -> bool: - if value is None: + def _register_from_namespace( + self, table_name: str, namespace: dict[str, Any] + ) -> bool: + if table_name not in namespace: return False + value = namespace[table_name] + return self._register_python_value(table_name, value) + + def _register_python_value(self, table_name: str, value: Any) -> bool: + pandas = _load_optional_module("pandas") + polars = _load_optional_module("polars") + polars_df = getattr(polars, "DataFrame", None) if polars is not None else None + + handlers = ( + (isinstance(value, DataFrame), self._register_datafusion_dataframe), + ( + isinstance(value, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)), + self._register_arrow_object, + ), + ( + pandas is not None and isinstance(value, pandas.DataFrame), + self._register_pandas_dataframe, + ), + ( + polars_df is not None and isinstance(value, polars_df), + self._register_polars_dataframe, + ), + ) + + for matches, handler in handlers: + if matches: + return handler(table_name, value) + + return False - registered = False - if isinstance(value, DataFrame): + def _register_datafusion_dataframe(self, table_name: str, value: DataFrame) -> bool: + try: self.register_view(table_name, value) - registered = True - elif isinstance(value, Table): - self.register_table(table_name, value) - registered = True - else: - provider = getattr(value, "__datafusion_table_provider__", None) - if callable(provider): - self.register_table_provider(table_name, value) - registered = True - elif hasattr(value, "__arrow_c_stream__") or hasattr( - value, "__arrow_c_array__" - ): - self.from_arrow(value, name=table_name) - registered = True - else: - module_name = getattr(type(value), "__module__", "") or "" - class_name = getattr(type(value), "__name__", "") or "" - if module_name.startswith("pandas.") and class_name == "DataFrame": - self.from_pandas(value, name=table_name) - registered = True - elif module_name.startswith("polars") and class_name == "DataFrame": - self.from_polars(value, name=table_name) - registered = True - - return registered + except Exception as exc: # noqa: BLE001 + warnings.warn( + "Failed to register DataFusion DataFrame for table " + f"'{table_name}': {exc}", + stacklevel=4, + ) + return False + return True + + def _register_arrow_object(self, table_name: str, value: Any) -> bool: + try: + self.from_arrow(value, table_name) + except Exception as exc: # noqa: BLE001 + warnings.warn( + "Failed to register Arrow data for table " + f"'{table_name}': {exc}", + stacklevel=4, + ) + return False + return True + + def _register_pandas_dataframe(self, table_name: str, value: Any) -> bool: + try: + self.from_pandas(value, table_name) + except Exception as exc: # noqa: BLE001 + warnings.warn( + "Failed to register pandas DataFrame for table " + f"'{table_name}': {exc}", + stacklevel=4, + ) + return False + return True + + def _register_polars_dataframe(self, table_name: str, value: Any) -> bool: + try: + self.from_polars(value, table_name) + except Exception as exc: # noqa: BLE001 + warnings.warn( + "Failed to register polars DataFrame for table " + f"'{table_name}': {exc}", + stacklevel=4, + ) + return False + return True def create_dataframe( self, From 1f36102abc73be92e56857b42382b50d1cc02eec Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 17:18:21 +0800 Subject: [PATCH 04/20] feat: enhance table name extraction and add tests for local Arrow, Pandas, and Polars dataframes --- python/datafusion/context.py | 45 ++++++++++++++++++++++++--------- python/tests/test_context.py | 48 +++++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index e41199cce..61479a542 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -692,21 +692,35 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """ return self.sql(query, options) - @staticmethod - def _extract_missing_table_name(error: Exception) -> str | None: + def _extract_missing_table_name(self, error: Exception) -> str | None: + """Return the missing table name if the exception represents that error.""" message = str(error) + + # Try the global pattern first (supports both table and view, case-insensitive) + match = _MISSING_TABLE_PATTERN.search(message) + if match: + table_name = match.group(1) + # Handle dotted table names by extracting the last part + if "." in table_name: + table_name = table_name.rsplit(".", 1)[-1] + return table_name + + # Fallback to additional patterns for broader compatibility patterns = ( - r"table '([^']+)' not found", r"Table not found: ['\"]?([^\s'\"]+)['\"]?", r"Table or CTE with name ['\"]?([^\s'\"]+)['\"]? not found", r"Invalid reference to table ['\"]?([^\s'\"]+)['\"]?", ) for pattern in patterns: if match := re.search(pattern, message): - return match.group(1) + table_name = match.group(1) + if "." in table_name: + table_name = table_name.rsplit(".", 1)[-1] + return table_name return None def _register_missing_table_from_callers(self, table_name: str) -> bool: + """Register a supported local object from caller stack frames.""" frame = inspect.currentframe() if frame is None: return False @@ -729,12 +743,14 @@ def _register_missing_table_from_callers(self, table_name: str) -> bool: def _register_from_namespace( self, table_name: str, namespace: dict[str, Any] ) -> bool: + """Register a table from a namespace if the table name exists.""" if table_name not in namespace: return False value = namespace[table_name] return self._register_python_value(table_name, value) def _register_python_value(self, table_name: str, value: Any) -> bool: + """Register a Python object as a table if it's a supported type.""" pandas = _load_optional_module("pandas") polars = _load_optional_module("polars") polars_df = getattr(polars, "DataFrame", None) if polars is not None else None @@ -753,6 +769,11 @@ def _register_python_value(self, table_name: str, value: Any) -> bool: polars_df is not None and isinstance(value, polars_df), self._register_polars_dataframe, ), + # Support objects with Arrow C Stream interface + ( + hasattr(value, "__arrow_c_stream__") or hasattr(value, "__arrow_c_array__"), + self._register_arrow_object, + ), ) for matches, handler in handlers: @@ -762,48 +783,48 @@ def _register_python_value(self, table_name: str, value: Any) -> bool: return False def _register_datafusion_dataframe(self, table_name: str, value: DataFrame) -> bool: + """Register a DataFusion DataFrame as a view.""" try: self.register_view(table_name, value) except Exception as exc: # noqa: BLE001 warnings.warn( - "Failed to register DataFusion DataFrame for table " - f"'{table_name}': {exc}", + f"Failed to register DataFusion DataFrame for table '{table_name}': {exc}", stacklevel=4, ) return False return True def _register_arrow_object(self, table_name: str, value: Any) -> bool: + """Register an Arrow object (Table, RecordBatch, RecordBatchReader, or stream).""" try: self.from_arrow(value, table_name) except Exception as exc: # noqa: BLE001 warnings.warn( - "Failed to register Arrow data for table " - f"'{table_name}': {exc}", + f"Failed to register Arrow data for table '{table_name}': {exc}", stacklevel=4, ) return False return True def _register_pandas_dataframe(self, table_name: str, value: Any) -> bool: + """Register a pandas DataFrame.""" try: self.from_pandas(value, table_name) except Exception as exc: # noqa: BLE001 warnings.warn( - "Failed to register pandas DataFrame for table " - f"'{table_name}': {exc}", + f"Failed to register pandas DataFrame for table '{table_name}': {exc}", stacklevel=4, ) return False return True def _register_polars_dataframe(self, table_name: str, value: Any) -> bool: + """Register a polars DataFrame.""" try: self.from_polars(value, table_name) except Exception as exc: # noqa: BLE001 warnings.warn( - "Failed to register polars DataFrame for table " - f"'{table_name}': {exc}", + f"Failed to register polars DataFrame for table '{table_name}': {exc}", stacklevel=4, ) return False diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 48fa9ed9e..531cc9f06 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -261,9 +261,9 @@ def test_sql_missing_table_without_auto_register(ctx): with pytest.raises(Exception, match="not found") as excinfo: ctx.sql("SELECT * FROM arrow_table").collect() - missing = getattr(excinfo.value, "missing_table_names", None) - assert missing is not None - assert "arrow_table" in set(ctx._extract_missing_table_names(excinfo.value)) + # Test that our extraction method works correctly + missing_tables = ctx._extract_missing_table_names(excinfo.value) + assert "arrow_table" in missing_tables def test_sql_auto_register_arrow_table(): @@ -348,6 +348,48 @@ def test_from_pandas(ctx): assert df.collect()[0].num_rows == 3 +def test_sql_from_local_arrow_table(ctx): + ctx.set_python_table_lookup(True) # Enable implicit table lookup + arrow_table = pa.Table.from_pydict({"a": [1, 2], "b": ["x", "y"]}) + + result = ctx.sql("SELECT * FROM arrow_table ORDER BY a").collect() + actual = pa.Table.from_batches(result) + expected = pa.Table.from_pydict({"a": [1, 2], "b": ["x", "y"]}) + + assert actual.equals(expected) + + +def test_sql_from_local_pandas_dataframe(ctx): + ctx.set_python_table_lookup(True) # Enable implicit table lookup + pd = pytest.importorskip("pandas") + pandas_df = pd.DataFrame({"a": [3, 1], "b": ["z", "y"]}) + + result = ctx.sql("SELECT * FROM pandas_df ORDER BY a").collect() + actual = pa.Table.from_batches(result) + expected = pa.Table.from_pydict({"a": [1, 3], "b": ["y", "z"]}) + + assert actual.equals(expected) + + +def test_sql_from_local_polars_dataframe(ctx): + ctx.set_python_table_lookup(True) # Enable implicit table lookup + pl = pytest.importorskip("polars") + polars_df = pl.DataFrame({"a": [2, 1], "b": ["beta", "alpha"]}) + + result = ctx.sql("SELECT * FROM polars_df ORDER BY a").collect() + actual = pa.Table.from_batches(result) + expected = pa.Table.from_pydict({"a": [1, 2], "b": ["alpha", "beta"]}) + + assert actual.equals(expected) + + +def test_sql_from_local_unsupported_object(ctx): + unsupported = object() + + with pytest.raises(Exception, match="table 'unsupported' not found"): + ctx.sql("SELECT * FROM unsupported").collect() + + def test_from_polars(ctx): # create a dataframe from Polars dataframe pd = pytest.importorskip("polars") From 92dde5b20b74340834d089b42dd8eb0049e12904 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 17:28:10 +0800 Subject: [PATCH 05/20] feat: enable automatic registration of Python objects in SQL queries and add corresponding tests --- docs/source/user-guide/dataframe/index.rst | 25 ++ python/datafusion/context.py | 281 ++++++++------------- python/tests/test_context.py | 40 ++- 3 files changed, 163 insertions(+), 183 deletions(-) diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 1387db0bd..177b0298d 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -228,6 +228,31 @@ Core Classes * :py:meth:`~datafusion.SessionContext.from_pandas` - Create from Pandas DataFrame * :py:meth:`~datafusion.SessionContext.from_arrow` - Create from Arrow data + ``SessionContext`` automatically resolves SQL table names that match + in-scope Python data objects. When ``auto_register_python_objects`` is + enabled (the default), a query such as ``ctx.sql("SELECT * FROM pdf")`` + will register a pandas or PyArrow object named ``pdf`` without calling + :py:meth:`~datafusion.SessionContext.from_pandas` or + :py:meth:`~datafusion.SessionContext.from_arrow` explicitly. This requires + the corresponding library (``pandas`` for pandas objects, ``pyarrow`` for + Arrow objects) to be installed. + + .. code-block:: python + + import pandas as pd + from datafusion import SessionContext + + ctx = SessionContext() + pdf = pd.DataFrame({"value": [1, 2, 3]}) + + df = ctx.sql("SELECT SUM(value) AS total FROM pdf") + print(df.to_pandas()) # automatically registers `pdf` + + To opt out, either pass ``auto_register_python_objects=False`` when + constructing the session, or call + :py:meth:`~datafusion.SessionContext.set_python_table_lookup` with + ``False`` to require explicit registration. + See: :py:class:`datafusion.SessionContext` Expression Classes diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 61479a542..a3c105db0 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -45,8 +45,6 @@ from ._internal import SQLOptions as SQLOptionsInternal from ._internal import expr as expr_internal -_MISSING_TABLE_PATTERN = re.compile(r"(?i)(?:table|view) '([^']+)' not found") - if TYPE_CHECKING: import pathlib from collections.abc import Sequence @@ -505,7 +503,7 @@ def __init__( config: SessionConfig | None = None, runtime: RuntimeEnvBuilder | None = None, *, - auto_register_python_variables: bool = False, + auto_register_python_objects: bool = True, ) -> None: """Main interface for executing queries with DataFusion. @@ -516,9 +514,9 @@ def __init__( Args: config: Session configuration options. runtime: Runtime configuration options. - auto_register_python_variables: Automatically register Arrow-like - Python objects referenced in SQL queries when they are available - in the caller's scope. + auto_register_python_objects: Automatically register referenced + Python objects (such as pandas or PyArrow data) when ``sql`` + queries reference them by name. Example usage: @@ -530,17 +528,11 @@ def __init__( ctx = SessionContext() df = ctx.read_csv("data.csv") """ - python_table_lookup = auto_register_python_variables # Use parameter as default - if config is not None: - python_table_lookup = config._python_table_lookup - config_internal = config.config_internal - else: - config_internal = None - - runtime_internal = runtime.config_internal if runtime is not None else None - - self.ctx = SessionContextInternal(config_internal, runtime_internal) - self._python_table_lookup = python_table_lookup + self.ctx = SessionContextInternal( + config.config_internal if config is not None else None, + runtime.config_internal if runtime is not None else None, + ) + self._auto_python_table_lookup = auto_register_python_objects def __repr__(self) -> str: """Print a string representation of the Session Context.""" @@ -567,27 +559,25 @@ def enable_url_table(self) -> SessionContext: klass = self.__class__ obj = klass.__new__(klass) obj.ctx = self.ctx.enable_url_table() - obj._python_table_lookup = self._python_table_lookup + obj._auto_python_table_lookup = getattr( + self, "_auto_python_table_lookup", True + ) return obj - def set_python_table_lookup(self, enabled: bool) -> None: - """Enable or disable implicit table lookup for Python objects.""" - self._python_table_lookup = enabled + def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: + """Enable or disable automatic registration of Python objects in SQL. - # Backward compatibility properties - @property - def auto_register_python_variables(self) -> bool: - """Toggle automatic registration of Python variables in SQL queries.""" - return self._python_table_lookup - - @auto_register_python_variables.setter - def auto_register_python_variables(self, enabled: bool) -> None: - self._python_table_lookup = bool(enabled) + Args: + enabled: When ``True`` (default), SQL queries automatically attempt + to resolve missing table names by looking up Python objects in + the caller's scope. When ``False``, missing tables will raise an + error unless they have been explicitly registered. - def _extract_missing_table_names(self, error: Exception) -> set[str]: - """Extract missing table names from error (backward compatibility).""" - missing_table = self._extract_missing_table_name(error) - return {missing_table} if missing_table else set() + Returns: + The current :py:class:`SessionContext` instance for chaining. + """ + self._auto_python_table_lookup = enabled + return self def register_object_store( self, schema: str, store: Any, host: str | None = None @@ -653,29 +643,28 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: Returns: DataFrame representation of the SQL query. """ - attempted_missing_tables: set[str] = set() + def _execute_sql() -> DataFrame: + if options is None: + return DataFrame(self.ctx.sql(query)) + return DataFrame( + self.ctx.sql_with_options(query, options.options_internal) + ) - while True: - try: - if options is None: - result = self.ctx.sql(query) - else: - result = self.ctx.sql_with_options(query, options.options_internal) - except Exception as exc: - missing_table = self._extract_missing_table_name(exc) - if ( - missing_table is None - or missing_table in attempted_missing_tables - or not self._python_table_lookup - ): - raise - - attempted_missing_tables.add(missing_table) - if not self._register_missing_table_from_callers(missing_table): - raise - continue + try: + return _execute_sql() + except Exception as err: + if not getattr(self, "_auto_python_table_lookup", True): + raise + + missing_tables = self._extract_missing_table_names(err) + if not missing_tables: + raise - return DataFrame(result) + registered = self._register_python_tables(missing_tables) + if not registered: + raise + + return _execute_sql() def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text. @@ -692,144 +681,74 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """ return self.sql(query, options) - def _extract_missing_table_name(self, error: Exception) -> str | None: - """Return the missing table name if the exception represents that error.""" - message = str(error) - - # Try the global pattern first (supports both table and view, case-insensitive) - match = _MISSING_TABLE_PATTERN.search(message) - if match: - table_name = match.group(1) - # Handle dotted table names by extracting the last part - if "." in table_name: - table_name = table_name.rsplit(".", 1)[-1] - return table_name - - # Fallback to additional patterns for broader compatibility - patterns = ( - r"Table not found: ['\"]?([^\s'\"]+)['\"]?", - r"Table or CTE with name ['\"]?([^\s'\"]+)['\"]? not found", - r"Invalid reference to table ['\"]?([^\s'\"]+)['\"]?", - ) - for pattern in patterns: - if match := re.search(pattern, message): - table_name = match.group(1) - if "." in table_name: - table_name = table_name.rsplit(".", 1)[-1] - return table_name - return None + @staticmethod + def _extract_missing_table_names(err: Exception) -> list[str]: + message = str(err) + matches = set() + for pattern in (r"table '([^']+)' not found", r"No table named '([^']+)'"): + matches.update(re.findall(pattern, message)) + + tables: list[str] = [] + for raw_name in matches: + if not raw_name: + continue + tables.append(raw_name.rsplit(".", 1)[-1]) + return tables - def _register_missing_table_from_callers(self, table_name: str) -> bool: - """Register a supported local object from caller stack frames.""" - frame = inspect.currentframe() - if frame is None: - return False + def _register_python_tables(self, tables: list[str]) -> bool: + registered_any = False + for table_name in tables: + if not table_name or self.table_exist(table_name): + continue + + python_obj = self._lookup_python_object(table_name) + if python_obj is None: + continue + if self._register_python_object(table_name, python_obj): + registered_any = True + + return registered_any + + @staticmethod + def _lookup_python_object(name: str) -> Any | None: + frame = inspect.currentframe() try: - frame = frame.f_back - if frame is None: - return False - frame = frame.f_back + if frame is not None: + frame = frame.f_back while frame is not None: - if self._register_from_namespace(table_name, frame.f_locals): - return True - if self._register_from_namespace(table_name, frame.f_globals): - return True + locals_dict = frame.f_locals + if name in locals_dict: + return locals_dict[name] + globals_dict = frame.f_globals + if name in globals_dict: + return globals_dict[name] frame = frame.f_back finally: del frame - return False + return None - def _register_from_namespace( - self, table_name: str, namespace: dict[str, Any] - ) -> bool: - """Register a table from a namespace if the table name exists.""" - if table_name not in namespace: - return False - value = namespace[table_name] - return self._register_python_value(table_name, value) - - def _register_python_value(self, table_name: str, value: Any) -> bool: - """Register a Python object as a table if it's a supported type.""" - pandas = _load_optional_module("pandas") - polars = _load_optional_module("polars") - polars_df = getattr(polars, "DataFrame", None) if polars is not None else None - - handlers = ( - (isinstance(value, DataFrame), self._register_datafusion_dataframe), - ( - isinstance(value, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)), - self._register_arrow_object, - ), - ( - pandas is not None and isinstance(value, pandas.DataFrame), - self._register_pandas_dataframe, - ), - ( - polars_df is not None and isinstance(value, polars_df), - self._register_polars_dataframe, - ), - # Support objects with Arrow C Stream interface - ( - hasattr(value, "__arrow_c_stream__") or hasattr(value, "__arrow_c_array__"), - self._register_arrow_object, - ), - ) + def _register_python_object(self, name: str, obj: Any) -> bool: + if isinstance(obj, DataFrame): + self.register_view(name, obj) + return True - for matches, handler in handlers: - if matches: - return handler(table_name, value) + if ( + obj.__class__.__module__.startswith("pandas.") + and obj.__class__.__name__ == "DataFrame" + ): + self.from_pandas(obj, name=name) + return True - return False + if isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)): + self.from_arrow(obj, name=name) + return True - def _register_datafusion_dataframe(self, table_name: str, value: DataFrame) -> bool: - """Register a DataFusion DataFrame as a view.""" - try: - self.register_view(table_name, value) - except Exception as exc: # noqa: BLE001 - warnings.warn( - f"Failed to register DataFusion DataFrame for table '{table_name}': {exc}", - stacklevel=4, - ) - return False - return True - - def _register_arrow_object(self, table_name: str, value: Any) -> bool: - """Register an Arrow object (Table, RecordBatch, RecordBatchReader, or stream).""" - try: - self.from_arrow(value, table_name) - except Exception as exc: # noqa: BLE001 - warnings.warn( - f"Failed to register Arrow data for table '{table_name}': {exc}", - stacklevel=4, - ) - return False - return True - - def _register_pandas_dataframe(self, table_name: str, value: Any) -> bool: - """Register a pandas DataFrame.""" - try: - self.from_pandas(value, table_name) - except Exception as exc: # noqa: BLE001 - warnings.warn( - f"Failed to register pandas DataFrame for table '{table_name}': {exc}", - stacklevel=4, - ) - return False - return True - - def _register_polars_dataframe(self, table_name: str, value: Any) -> bool: - """Register a polars DataFrame.""" - try: - self.from_polars(value, table_name) - except Exception as exc: # noqa: BLE001 - warnings.warn( - f"Failed to register polars DataFrame for table '{table_name}': {exc}", - stacklevel=4, - ) - return False - return True + if hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__"): + self.from_arrow(obj, name=name) + return True + return False def create_dataframe( self, partitions: list[list[pa.RecordBatch]], diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 531cc9f06..68e231200 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -17,6 +17,7 @@ import datetime as dt import gzip import pathlib +from uuid import uuid4 import pyarrow as pa import pyarrow.dataset as ds @@ -589,8 +590,6 @@ def test_table_exist(ctx): def test_table_not_found(ctx): - from uuid import uuid4 - with pytest.raises(KeyError): ctx.table(f"not-found-{uuid4()}") @@ -739,6 +738,43 @@ def test_sql_with_options_no_statements(ctx): ctx.sql_with_options(sql, options=options) +def test_sql_auto_register_pandas(): + pd = pytest.importorskip("pandas") + + ctx = SessionContext() + pdf = pd.DataFrame({"value": [1, 2, 3]}) + assert len(pdf) == 3 + + batches = ctx.sql("SELECT SUM(value) AS total FROM pdf").collect() + assert batches[0].column(0).to_pylist()[0] == 6 + + +def test_sql_auto_register_arrow(): + ctx = SessionContext() + arrow_table = pa.table({"value": [1, 2, 3, 4]}) + assert arrow_table.num_rows == 4 + + batches = ctx.sql("SELECT COUNT(*) AS cnt FROM arrow_table").collect() + assert batches[0].column(0).to_pylist()[0] == 4 + + +def test_sql_auto_register_disabled(): + pd = pytest.importorskip("pandas") + + ctx = SessionContext(auto_register_python_objects=False) + pdf = pd.DataFrame({"value": [1, 2, 3]}) + assert len(pdf) == 3 + + with pytest.raises(Exception) as excinfo: + ctx.sql("SELECT * FROM pdf").collect() + + assert "not found" in str(excinfo.value) + + ctx.set_python_table_lookup(True) + batches = ctx.sql("SELECT COUNT(*) AS cnt FROM pdf").collect() + assert batches[0].column(0).to_pylist()[0] == 3 + + @pytest.fixture def batch(): return pa.RecordBatch.from_arrays( From db2d23920c435b51c5c28b5f3cd5b38b707e3d87 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 18:13:15 +0800 Subject: [PATCH 06/20] feat: add deprecation warnings and alias handling for auto_register_python_variables in SessionContext --- python/datafusion/context.py | 62 ++++++++++++++++++++++++++++++++++-- python/tests/test_context.py | 36 ++++++++++++++++++++- 2 files changed, 94 insertions(+), 4 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index a3c105db0..14a522b1e 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -24,7 +24,7 @@ import re import warnings from functools import cache -from typing import TYPE_CHECKING, Any, Iterator, Protocol +from typing import TYPE_CHECKING, Any, Protocol try: from warnings import deprecated # Python 3.13+ @@ -64,6 +64,13 @@ def _load_optional_module(module_name: str) -> Any | None: return None +_AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED = ( + "SessionContext.auto_register_python_variables is deprecated; use " + "SessionContext.set_python_table_lookup() or the " + "'auto_register_python_objects' keyword argument instead." +) + + class ArrowStreamExportable(Protocol): """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface. @@ -503,7 +510,8 @@ def __init__( config: SessionConfig | None = None, runtime: RuntimeEnvBuilder | None = None, *, - auto_register_python_objects: bool = True, + auto_register_python_objects: bool | None = None, + auto_register_python_variables: bool | None = None, ) -> None: """Main interface for executing queries with DataFusion. @@ -517,6 +525,9 @@ def __init__( auto_register_python_objects: Automatically register referenced Python objects (such as pandas or PyArrow data) when ``sql`` queries reference them by name. + auto_register_python_variables: Deprecated alias for + ``auto_register_python_objects``. When provided, it overrides + the automatic registration behavior. Example usage: @@ -532,7 +543,33 @@ def __init__( config.config_internal if config is not None else None, runtime.config_internal if runtime is not None else None, ) - self._auto_python_table_lookup = auto_register_python_objects + if auto_register_python_variables is not None: + warnings.warn( + _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED, + DeprecationWarning, + stacklevel=2, + ) + + if ( + auto_register_python_objects is not None + and auto_register_python_variables is not None + and auto_register_python_objects != auto_register_python_variables + ): + conflict_message = ( + "auto_register_python_objects and auto_register_python_variables " + "were provided with conflicting values." + ) + raise ValueError(conflict_message) + + if auto_register_python_objects is None: + if auto_register_python_variables is None: + auto_python_table_lookup = True + else: + auto_python_table_lookup = auto_register_python_variables + else: + auto_python_table_lookup = auto_register_python_objects + + self._auto_python_table_lookup = auto_python_table_lookup def __repr__(self) -> str: """Print a string representation of the Session Context.""" @@ -579,6 +616,25 @@ def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: self._auto_python_table_lookup = enabled return self + @property + def auto_register_python_variables(self) -> bool: + """Deprecated alias for :py:meth:`set_python_table_lookup`.""" + warnings.warn( + _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED, + DeprecationWarning, + stacklevel=2, + ) + return getattr(self, "_auto_python_table_lookup", True) + + @auto_register_python_variables.setter + def auto_register_python_variables(self, enabled: bool) -> None: + warnings.warn( + _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED, + DeprecationWarning, + stacklevel=2, + ) + self.set_python_table_lookup(enabled) + def register_object_store( self, schema: str, store: Any, host: str | None = None ) -> None: diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 68e231200..6fb3dfbfd 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -257,9 +257,10 @@ def test_from_pylist(ctx): def test_sql_missing_table_without_auto_register(ctx): + ctx.set_python_table_lookup(False) arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 - with pytest.raises(Exception, match="not found") as excinfo: + with pytest.raises(Exception, match="not found|No table named") as excinfo: ctx.sql("SELECT * FROM arrow_table").collect() # Test that our extraction method works correctly @@ -319,6 +320,39 @@ def test_sql_auto_register_polars_dataframe(): assert result[0].column(0).to_pylist()[0] == 2 +def test_session_context_constructor_alias_disables_lookup(): + with pytest.deprecated_call(): + ctx = SessionContext(auto_register_python_variables=False) + + arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 + + with pytest.raises(Exception, match="not found|No table named"): + ctx.sql("SELECT * FROM arrow_table").collect() + + with pytest.deprecated_call(): + assert ctx.auto_register_python_variables is False + + +def test_session_context_property_alias_setter_enables_lookup(): + ctx = SessionContext(auto_register_python_objects=False) + arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 + + with pytest.raises(Exception, match="not found|No table named"): + ctx.sql("SELECT COUNT(*) FROM arrow_table").collect() + + with pytest.deprecated_call(): + ctx.auto_register_python_variables = True + + result = ctx.sql( + "SELECT SUM(value) AS total FROM arrow_table", + ).collect() + + assert result[0].column(0).to_pylist()[0] == 6 + + with pytest.deprecated_call(): + assert ctx.auto_register_python_variables is True + + def test_from_pydict(ctx): # create a dataframe from Python dictionary data = {"a": [1, 2, 3], "b": [4, 5, 6]} From 8fc3e1c5184ccf0adfa919a10beae0ee754b195b Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 18:18:45 +0800 Subject: [PATCH 07/20] feat: enhance SessionContext to support automatic registration of Python objects via session config --- docs/source/user-guide/dataframe/index.rst | 20 ++++---- python/datafusion/context.py | 56 ++++++++++++---------- python/tests/test_context.py | 8 ++-- 3 files changed, 45 insertions(+), 39 deletions(-) diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index 177b0298d..ced7693f8 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -228,10 +228,10 @@ Core Classes * :py:meth:`~datafusion.SessionContext.from_pandas` - Create from Pandas DataFrame * :py:meth:`~datafusion.SessionContext.from_arrow` - Create from Arrow data - ``SessionContext`` automatically resolves SQL table names that match - in-scope Python data objects. When ``auto_register_python_objects`` is - enabled (the default), a query such as ``ctx.sql("SELECT * FROM pdf")`` - will register a pandas or PyArrow object named ``pdf`` without calling + ``SessionContext`` can automatically resolve SQL table names that match + in-scope Python data objects. When automatic lookup is enabled, a query + such as ``ctx.sql("SELECT * FROM pdf")`` will register a pandas or + PyArrow object named ``pdf`` without calling :py:meth:`~datafusion.SessionContext.from_pandas` or :py:meth:`~datafusion.SessionContext.from_arrow` explicitly. This requires the corresponding library (``pandas`` for pandas objects, ``pyarrow`` for @@ -242,16 +242,18 @@ Core Classes import pandas as pd from datafusion import SessionContext - ctx = SessionContext() + ctx = SessionContext(auto_register_python_objects=True) pdf = pd.DataFrame({"value": [1, 2, 3]}) df = ctx.sql("SELECT SUM(value) AS total FROM pdf") print(df.to_pandas()) # automatically registers `pdf` - To opt out, either pass ``auto_register_python_objects=False`` when - constructing the session, or call - :py:meth:`~datafusion.SessionContext.set_python_table_lookup` with - ``False`` to require explicit registration. + Automatic lookup is disabled by default. Enable it by passing + ``auto_register_python_objects=True`` when constructing the session or by + configuring :py:class:`~datafusion.SessionConfig` with + :py:meth:`~datafusion.SessionConfig.with_python_table_lookup`. Use + :py:meth:`~datafusion.SessionContext.set_python_table_lookup` to toggle the + behaviour at runtime. See: :py:class:`datafusion.SessionContext` diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 14a522b1e..e5df1f0bc 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -524,7 +524,10 @@ def __init__( runtime: Runtime configuration options. auto_register_python_objects: Automatically register referenced Python objects (such as pandas or PyArrow data) when ``sql`` - queries reference them by name. + queries reference them by name. When omitted, this defaults to + the value configured via + :py:meth:`~datafusion.SessionConfig.with_python_table_lookup` + (``False`` unless explicitly enabled). auto_register_python_variables: Deprecated alias for ``auto_register_python_objects``. When provided, it overrides the automatic registration behavior. @@ -543,6 +546,7 @@ def __init__( config.config_internal if config is not None else None, runtime.config_internal if runtime is not None else None, ) + if auto_register_python_variables is not None: warnings.warn( _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED, @@ -550,26 +554,26 @@ def __init__( stacklevel=2, ) - if ( - auto_register_python_objects is not None - and auto_register_python_variables is not None - and auto_register_python_objects != auto_register_python_variables - ): - conflict_message = ( - "auto_register_python_objects and auto_register_python_variables " - "were provided with conflicting values." - ) - raise ValueError(conflict_message) + if auto_register_python_variables is not None and auto_register_python_objects is not None: + if auto_register_python_objects != auto_register_python_variables: + conflict_message = ( + "auto_register_python_objects and auto_register_python_variables " + "were provided with conflicting values." + ) + raise ValueError(conflict_message) - if auto_register_python_objects is None: - if auto_register_python_variables is None: - auto_python_table_lookup = True - else: - auto_python_table_lookup = auto_register_python_variables - else: + # Determine the final value for python table lookup + if auto_register_python_objects is not None: auto_python_table_lookup = auto_register_python_objects + elif auto_register_python_variables is not None: + auto_python_table_lookup = auto_register_python_variables + else: + # Default to session config value or False if not configured + auto_python_table_lookup = getattr( + config, "_python_table_lookup", False + ) - self._auto_python_table_lookup = auto_python_table_lookup + self._auto_python_table_lookup = bool(auto_python_table_lookup) def __repr__(self) -> str: """Print a string representation of the Session Context.""" @@ -597,7 +601,7 @@ def enable_url_table(self) -> SessionContext: obj = klass.__new__(klass) obj.ctx = self.ctx.enable_url_table() obj._auto_python_table_lookup = getattr( - self, "_auto_python_table_lookup", True + self, "_auto_python_table_lookup", False ) return obj @@ -605,10 +609,10 @@ def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: """Enable or disable automatic registration of Python objects in SQL. Args: - enabled: When ``True`` (default), SQL queries automatically attempt - to resolve missing table names by looking up Python objects in - the caller's scope. When ``False``, missing tables will raise an - error unless they have been explicitly registered. + enabled: When ``True``, SQL queries automatically attempt to + resolve missing table names by looking up Python objects in the + caller's scope. Use ``False`` to require explicit registration + of any referenced tables. Returns: The current :py:class:`SessionContext` instance for chaining. @@ -624,7 +628,7 @@ def auto_register_python_variables(self) -> bool: DeprecationWarning, stacklevel=2, ) - return getattr(self, "_auto_python_table_lookup", True) + return bool(getattr(self, "_auto_python_table_lookup", False)) @auto_register_python_variables.setter def auto_register_python_variables(self, enabled: bool) -> None: @@ -633,7 +637,7 @@ def auto_register_python_variables(self, enabled: bool) -> None: DeprecationWarning, stacklevel=2, ) - self.set_python_table_lookup(enabled) + self.set_python_table_lookup(bool(enabled)) def register_object_store( self, schema: str, store: Any, host: str | None = None @@ -709,7 +713,7 @@ def _execute_sql() -> DataFrame: try: return _execute_sql() except Exception as err: - if not getattr(self, "_auto_python_table_lookup", True): + if not getattr(self, "_auto_python_table_lookup", False): raise missing_tables = self._extract_missing_table_names(err) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 6fb3dfbfd..5c82a105e 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -772,10 +772,10 @@ def test_sql_with_options_no_statements(ctx): ctx.sql_with_options(sql, options=options) -def test_sql_auto_register_pandas(): +def test_session_config_python_table_lookup_enables_auto_registration(): pd = pytest.importorskip("pandas") - ctx = SessionContext() + ctx = SessionContext(config=SessionConfig().with_python_table_lookup(True)) pdf = pd.DataFrame({"value": [1, 2, 3]}) assert len(pdf) == 3 @@ -784,7 +784,7 @@ def test_sql_auto_register_pandas(): def test_sql_auto_register_arrow(): - ctx = SessionContext() + ctx = SessionContext(auto_register_python_objects=True) arrow_table = pa.table({"value": [1, 2, 3, 4]}) assert arrow_table.num_rows == 4 @@ -795,7 +795,7 @@ def test_sql_auto_register_arrow(): def test_sql_auto_register_disabled(): pd = pytest.importorskip("pandas") - ctx = SessionContext(auto_register_python_objects=False) + ctx = SessionContext() pdf = pd.DataFrame({"value": [1, 2, 3]}) assert len(pdf) == 3 From b733408a0685b3334d10ffc8e050a4129be56f1a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 18:43:10 +0800 Subject: [PATCH 08/20] fix: correct parameter name from auto_register_python_variables to auto_register_python_objects in SessionContext --- docs/source/user-guide/sql.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user-guide/sql.rst b/docs/source/user-guide/sql.rst index 9556bf3e0..d11f277bd 100644 --- a/docs/source/user-guide/sql.rst +++ b/docs/source/user-guide/sql.rst @@ -50,7 +50,7 @@ objects that appear in SQL queries. This removes the need to call import pyarrow as pa from datafusion import SessionContext - ctx = SessionContext(auto_register_python_variables=True) + ctx = SessionContext(auto_register_python_objects=True) orders = pa.Table.from_pydict({"item": ["apple", "pear"], "qty": [5, 2]}) @@ -60,5 +60,5 @@ objects that appear in SQL queries. This removes the need to call The feature inspects the call stack for variables whose names match missing tables and registers them if they expose Arrow data (including pandas and Polars DataFrames). Existing contexts can enable or disable the behavior at -runtime through the :py:attr:`SessionContext.auto_register_python_variables` -property. +runtime through :py:meth:`SessionContext.set_python_table_lookup` or by passing +``auto_register_python_objects`` when constructing the session. From fb3dadbb5a40063ae6e6a2279366fbe9f539be85 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 18:56:28 +0800 Subject: [PATCH 09/20] feat: add normalization for missing table names extraction in SessionContext --- python/datafusion/context.py | 28 ++++++++++++++++++++++------ python/tests/test_context.py | 14 ++++++++++++++ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index e5df1f0bc..3746f1129 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -743,17 +743,33 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: @staticmethod def _extract_missing_table_names(err: Exception) -> list[str]: + def _normalize(names: list[Any]) -> list[str]: + tables: list[str] = [] + for raw_name in names: + if not raw_name: + continue + raw_str = str(raw_name) + tables.append(raw_str.rsplit(".", 1)[-1]) + return tables + + missing_tables = getattr(err, "missing_table_names", None) + if missing_tables is not None: + if isinstance(missing_tables, str): + candidates: list[Any] = [missing_tables] + else: + try: + candidates = list(missing_tables) + except TypeError: + candidates = [missing_tables] + + return _normalize(candidates) + message = str(err) matches = set() for pattern in (r"table '([^']+)' not found", r"No table named '([^']+)'"): matches.update(re.findall(pattern, message)) - tables: list[str] = [] - for raw_name in matches: - if not raw_name: - continue - tables.append(raw_name.rsplit(".", 1)[-1]) - return tables + return _normalize(list(matches)) def _register_python_tables(self, tables: list[str]) -> bool: registered_any = False diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 5c82a105e..916a94eb5 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -268,6 +268,20 @@ def test_sql_missing_table_without_auto_register(ctx): assert "arrow_table" in missing_tables +def test_extract_missing_table_names_from_attribute(): + class MissingTablesError(Exception): + def __init__(self) -> None: + super().__init__("custom error") + self.missing_table_names = ( + "catalog.schema.arrow_table", + "plain_table", + ) + + err = MissingTablesError() + missing_tables = SessionContext._extract_missing_table_names(err) + assert missing_tables == ["arrow_table", "plain_table"] + + def test_sql_auto_register_arrow_table(): ctx = SessionContext(auto_register_python_variables=True) arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 From 6454b8ce39546938d184816dff6fcf00b4c23303 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 19:26:47 +0800 Subject: [PATCH 10/20] test: add unit test for automatic registration of multiple Python tables in SQL queries --- python/datafusion/context.py | 28 ++++++++++++++++------------ python/tests/test_context.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 3746f1129..f20c66cd5 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -710,21 +710,25 @@ def _execute_sql() -> DataFrame: self.ctx.sql_with_options(query, options.options_internal) ) - try: - return _execute_sql() - except Exception as err: - if not getattr(self, "_auto_python_table_lookup", False): - raise + auto_lookup_enabled = getattr(self, "_auto_python_table_lookup", False) + + while True: + try: + return _execute_sql() + except Exception as err: + if not auto_lookup_enabled: + raise - missing_tables = self._extract_missing_table_names(err) - if not missing_tables: - raise + missing_tables = self._extract_missing_table_names(err) + if not missing_tables: + raise - registered = self._register_python_tables(missing_tables) - if not registered: - raise + registered = self._register_python_tables(missing_tables) + if not registered: + raise - return _execute_sql() + # Retry to allow registering additional tables referenced in the query. + continue def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame: """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text. diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 916a94eb5..df151ba36 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -294,6 +294,35 @@ def test_sql_auto_register_arrow_table(): assert result[0].column(0).to_pylist()[0] == 6 +def test_sql_auto_register_multiple_tables_single_query(): + ctx = SessionContext(auto_register_python_objects=True) + + customers = pa.Table.from_pydict( # noqa: F841 + {"customer_id": [1, 2], "name": ["Alice", "Bob"]} + ) + orders = pa.Table.from_pydict( # noqa: F841 + {"order_id": [100, 200], "customer_id": [1, 2]} + ) + + result = ctx.sql( + """ + SELECT c.customer_id, o.order_id + FROM customers c + JOIN orders o ON c.customer_id = o.customer_id + ORDER BY o.order_id + """ + ).collect() + + actual = pa.Table.from_batches(result) + expected = pa.Table.from_pydict( + {"customer_id": [1, 2], "order_id": [100, 200]} + ) + + assert actual.equals(expected) + assert ctx.table_exist("customers") + assert ctx.table_exist("orders") + + def test_sql_auto_register_arrow_outer_scope(): ctx = SessionContext() ctx.auto_register_python_variables = True From dc1b3926b13100770a1aa08901d308cca201fefc Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 19:29:12 +0800 Subject: [PATCH 11/20] refactor: clean up unused imports and simplify conditional logic in SessionContext --- python/datafusion/context.py | 39 +++++++++++++----------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index f20c66cd5..683c8a5ea 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,11 +19,9 @@ from __future__ import annotations -import importlib import inspect import re import warnings -from functools import cache from typing import TYPE_CHECKING, Any, Protocol try: @@ -54,16 +52,6 @@ from datafusion.plan import ExecutionPlan, LogicalPlan - -@cache -def _load_optional_module(module_name: str) -> Any | None: - """Return the module for *module_name* if it can be imported.""" - try: - return importlib.import_module(module_name) - except ModuleNotFoundError: - return None - - _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED = ( "SessionContext.auto_register_python_variables is deprecated; use " "SessionContext.set_python_table_lookup() or the " @@ -554,13 +542,16 @@ def __init__( stacklevel=2, ) - if auto_register_python_variables is not None and auto_register_python_objects is not None: - if auto_register_python_objects != auto_register_python_variables: - conflict_message = ( - "auto_register_python_objects and auto_register_python_variables " - "were provided with conflicting values." - ) - raise ValueError(conflict_message) + if ( + auto_register_python_variables is not None + and auto_register_python_objects is not None + and auto_register_python_objects != auto_register_python_variables + ): + conflict_message = ( + "auto_register_python_objects and auto_register_python_variables " + "were provided with conflicting values." + ) + raise ValueError(conflict_message) # Determine the final value for python table lookup if auto_register_python_objects is not None: @@ -569,9 +560,7 @@ def __init__( auto_python_table_lookup = auto_register_python_variables else: # Default to session config value or False if not configured - auto_python_table_lookup = getattr( - config, "_python_table_lookup", False - ) + auto_python_table_lookup = getattr(config, "_python_table_lookup", False) self._auto_python_table_lookup = bool(auto_python_table_lookup) @@ -703,12 +692,11 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame: Returns: DataFrame representation of the SQL query. """ + def _execute_sql() -> DataFrame: if options is None: return DataFrame(self.ctx.sql(query)) - return DataFrame( - self.ctx.sql_with_options(query, options.options_internal) - ) + return DataFrame(self.ctx.sql_with_options(query, options.options_internal)) auto_lookup_enabled = getattr(self, "_auto_python_table_lookup", False) @@ -829,6 +817,7 @@ def _register_python_object(self, name: str, obj: Any) -> bool: return True return False + def create_dataframe( self, partitions: list[list[pa.RecordBatch]], From 904c1ca5e0b883854caae13fcd5f43eb5290bf06 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 20:44:21 +0800 Subject: [PATCH 12/20] feat: add support for automatic registration of Polars DataFrame in SessionContext --- python/datafusion/context.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 683c8a5ea..4ef233320 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -801,6 +801,13 @@ def _register_python_object(self, name: str, obj: Any) -> bool: self.register_view(name, obj) return True + if ( + obj.__class__.__module__.startswith("polars.") + and obj.__class__.__name__ == "DataFrame" + ): + self.from_polars(obj, name=name) + return True + if ( obj.__class__.__module__.startswith("pandas.") and obj.__class__.__name__ == "DataFrame" From 1764a57c8b0f077a0717cef86433a7f7548b382b Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 21:26:12 +0800 Subject: [PATCH 13/20] test: add tests for case-insensitive lookup and skipping None shadowing in SQL queries --- python/datafusion/context.py | 27 +++++++++++++++++++++++---- python/tests/test_context.py | 25 +++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 4ef233320..8a6b7bcf3 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -784,13 +784,32 @@ def _lookup_python_object(name: str) -> Any | None: try: if frame is not None: frame = frame.f_back + lower_name = name.lower() + + def _match(mapping: dict[str, Any]) -> Any | None: + if not mapping: + return None + + value = mapping.get(name) + if value is not None: + return value + + for key, candidate in mapping.items(): + if isinstance(key, str) and key.lower() == lower_name: + if candidate is not None: + return candidate + + return None + while frame is not None: locals_dict = frame.f_locals - if name in locals_dict: - return locals_dict[name] + match = _match(locals_dict) + if match is not None: + return match globals_dict = frame.f_globals - if name in globals_dict: - return globals_dict[name] + match = _match(globals_dict) + if match is not None: + return match frame = frame.f_back finally: del frame diff --git a/python/tests/test_context.py b/python/tests/test_context.py index df151ba36..8c8da4c77 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -337,6 +337,31 @@ def run_query(): assert result[0].column(0).to_pylist()[0] == 4 +def test_sql_auto_register_skips_none_shadowing(): + ctx = SessionContext(auto_register_python_objects=True) + mytable = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 + + def run_query(): + mytable = None # noqa: F841 + return ctx.sql( + "SELECT SUM(value) AS total FROM mytable", + ).collect() + + batches = run_query() + assert batches[0].column(0).to_pylist()[0] == 6 + + +def test_sql_auto_register_case_insensitive_lookup(): + ctx = SessionContext(auto_register_python_objects=True) + MyTable = pa.Table.from_pydict({"value": [2, 3]}) # noqa: F841 + + batches = ctx.sql( + "SELECT SUM(value) AS total FROM mytable", + ).collect() + + assert batches[0].column(0).to_pylist()[0] == 5 + + def test_sql_auto_register_pandas_dataframe(): pd = pytest.importorskip("pandas") From b9041ba23f635368c8bc1378f8eb0212f0c15baf Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 19 Sep 2025 21:56:01 +0800 Subject: [PATCH 14/20] test: add unit test for refreshing reassigned pandas DataFrame in SQL context --- python/datafusion/context.py | 67 ++++++++++++++++++++++++++++-------- python/tests/test_context.py | 21 +++++++++++ 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 8a6b7bcf3..8b46a306d 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -22,6 +22,7 @@ import inspect import re import warnings +import weakref from typing import TYPE_CHECKING, Any, Protocol try: @@ -563,6 +564,9 @@ def __init__( auto_python_table_lookup = getattr(config, "_python_table_lookup", False) self._auto_python_table_lookup = bool(auto_python_table_lookup) + self._python_table_bindings: dict[ + str, tuple[weakref.ReferenceType[Any] | None, int] + ] = {} def __repr__(self) -> str: """Print a string representation of the Session Context.""" @@ -592,6 +596,9 @@ def enable_url_table(self) -> SessionContext: obj._auto_python_table_lookup = getattr( self, "_auto_python_table_lookup", False ) + obj._python_table_bindings = getattr( + self, "_python_table_bindings", {} + ).copy() return obj def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: @@ -700,10 +707,13 @@ def _execute_sql() -> DataFrame: auto_lookup_enabled = getattr(self, "_auto_python_table_lookup", False) + if auto_lookup_enabled: + self._refresh_python_table_bindings() + while True: try: return _execute_sql() - except Exception as err: + except Exception as err: # noqa: PERF203 if not auto_lookup_enabled: raise @@ -815,34 +825,60 @@ def _match(mapping: dict[str, Any]) -> Any | None: del frame return None + def _refresh_python_table_bindings(self) -> None: + bindings = getattr(self, "_python_table_bindings", {}) + for table_name, (obj_ref, cached_id) in list(bindings.items()): + cached_obj = obj_ref() if obj_ref is not None else None + current_obj = self._lookup_python_object(table_name) + weakref_dead = obj_ref is not None and cached_obj is None + id_mismatch = current_obj is not None and id(current_obj) != cached_id + + if not (weakref_dead or id_mismatch): + continue + + self.deregister_table(table_name) + + if current_obj is None: + bindings.pop(table_name, None) + continue + + if self._register_python_object(table_name, current_obj): + continue + + bindings.pop(table_name, None) + def _register_python_object(self, name: str, obj: Any) -> bool: + registered = False + if isinstance(obj, DataFrame): self.register_view(name, obj) - return True - - if ( + registered = True + elif ( obj.__class__.__module__.startswith("polars.") and obj.__class__.__name__ == "DataFrame" ): self.from_polars(obj, name=name) - return True - - if ( + registered = True + elif ( obj.__class__.__module__.startswith("pandas.") and obj.__class__.__name__ == "DataFrame" ): self.from_pandas(obj, name=name) - return True - - if isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)): + registered = True + elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)) or ( + hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__") + ): self.from_arrow(obj, name=name) - return True + registered = True - if hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__"): - self.from_arrow(obj, name=name) - return True + if registered: + try: + reference: weakref.ReferenceType[Any] | None = weakref.ref(obj) + except TypeError: + reference = None + self._python_table_bindings[name] = (reference, id(obj)) - return False + return registered def create_dataframe( self, @@ -981,6 +1017,7 @@ def register_table(self, name: str, table: Table) -> None: def deregister_table(self, name: str) -> None: """Remove a table from the session.""" self.ctx.deregister_table(name) + self._python_table_bindings.pop(name, None) def catalog_names(self) -> set[str]: """Returns the list of catalogs in this context.""" diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 8c8da4c77..b1ee548ab 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -375,6 +375,27 @@ def test_sql_auto_register_pandas_dataframe(): assert pytest.approx(result[0].column(0).to_pylist()[0]) == 2.5 +def test_sql_auto_register_refreshes_reassigned_dataframe(): + pd = pytest.importorskip("pandas") + + ctx = SessionContext(auto_register_python_objects=True) + pandas_df = pd.DataFrame({"value": [1, 2, 3]}) + + first = ctx.sql( + "SELECT SUM(value) AS total FROM pandas_df", + ).collect() + + assert first[0].column(0).to_pylist()[0] == 6 + + pandas_df = pd.DataFrame({"value": [10, 20]}) # noqa: F841 + + second = ctx.sql( + "SELECT SUM(value) AS total FROM pandas_df", + ).collect() + + assert second[0].column(0).to_pylist()[0] == 30 + + def test_sql_auto_register_polars_dataframe(): pl = pytest.importorskip("polars") From ac1d6e12e8552d3527de02690e957dbceb49f5a6 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 21 Sep 2025 17:23:42 +0800 Subject: [PATCH 15/20] feat: enhance error handling for missing tables in SQL queries --- python/tests/test_context.py | 12 ++++++++++++ src/context.rs | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index b1ee548ab..e4ae61583 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -268,6 +268,18 @@ def test_sql_missing_table_without_auto_register(ctx): assert "arrow_table" in missing_tables +def test_sql_missing_table_exposes_missing_table_names(ctx): + ctx.set_python_table_lookup(False) + + with pytest.raises(Exception) as excinfo: + ctx.sql("SELECT * FROM missing_table").collect() + + missing_tables = getattr(excinfo.value, "missing_table_names", None) + assert missing_tables is not None + normalized = [str(name).rsplit(".", 1)[-1] for name in missing_tables] + assert normalized == ["missing_table"] + + def test_extract_missing_table_names_from_attribute(): class MissingTablesError(Exception): def __init__(self) -> None: diff --git a/src/context.rs b/src/context.rs index 9dd9db37a..9af58ac39 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1237,11 +1237,15 @@ fn collect_missing_table_names_recursive(err: &DataFusionError, acc: &mut HashSe } fn parse_missing_table_names_in_message(message: &str, acc: &mut HashSet) { - const LOOKUPS: [(&str, char); 4] = [ + const LOOKUPS: [(&str, char); 8] = [ ("table '", '\''), ("view '", '\''), ("table \"", '"'), ("view \"", '"'), + ("table named '", '\''), + ("view named '", '\''), + ("table named \"", '"'), + ("view named \"", '"'), ]; let lower = message.to_ascii_lowercase(); From 15b5cecbf90d3414f25a5ed222a1349e5c33281a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 21 Sep 2025 18:09:45 +0800 Subject: [PATCH 16/20] refactor: replace auto_register_python_variables with auto_register_python_objects in SessionContext --- python/datafusion/context.py | 80 ++++++------------------------------ python/tests/test_context.py | 57 +++++-------------------- 2 files changed, 23 insertions(+), 114 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 8b46a306d..75e4d7a09 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -53,12 +53,6 @@ from datafusion.plan import ExecutionPlan, LogicalPlan -_AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED = ( - "SessionContext.auto_register_python_variables is deprecated; use " - "SessionContext.set_python_table_lookup() or the " - "'auto_register_python_objects' keyword argument instead." -) - class ArrowStreamExportable(Protocol): """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface. @@ -500,7 +494,6 @@ def __init__( runtime: RuntimeEnvBuilder | None = None, *, auto_register_python_objects: bool | None = None, - auto_register_python_variables: bool | None = None, ) -> None: """Main interface for executing queries with DataFusion. @@ -517,9 +510,6 @@ def __init__( the value configured via :py:meth:`~datafusion.SessionConfig.with_python_table_lookup` (``False`` unless explicitly enabled). - auto_register_python_variables: Deprecated alias for - ``auto_register_python_objects``. When provided, it overrides - the automatic registration behavior. Example usage: @@ -536,29 +526,9 @@ def __init__( runtime.config_internal if runtime is not None else None, ) - if auto_register_python_variables is not None: - warnings.warn( - _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED, - DeprecationWarning, - stacklevel=2, - ) - - if ( - auto_register_python_variables is not None - and auto_register_python_objects is not None - and auto_register_python_objects != auto_register_python_variables - ): - conflict_message = ( - "auto_register_python_objects and auto_register_python_variables " - "were provided with conflicting values." - ) - raise ValueError(conflict_message) - # Determine the final value for python table lookup if auto_register_python_objects is not None: auto_python_table_lookup = auto_register_python_objects - elif auto_register_python_variables is not None: - auto_python_table_lookup = auto_register_python_variables else: # Default to session config value or False if not configured auto_python_table_lookup = getattr(config, "_python_table_lookup", False) @@ -596,9 +566,7 @@ def enable_url_table(self) -> SessionContext: obj._auto_python_table_lookup = getattr( self, "_auto_python_table_lookup", False ) - obj._python_table_bindings = getattr( - self, "_python_table_bindings", {} - ).copy() + obj._python_table_bindings = getattr(self, "_python_table_bindings", {}).copy() return obj def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: @@ -616,25 +584,6 @@ def set_python_table_lookup(self, enabled: bool = True) -> SessionContext: self._auto_python_table_lookup = enabled return self - @property - def auto_register_python_variables(self) -> bool: - """Deprecated alias for :py:meth:`set_python_table_lookup`.""" - warnings.warn( - _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED, - DeprecationWarning, - stacklevel=2, - ) - return bool(getattr(self, "_auto_python_table_lookup", False)) - - @auto_register_python_variables.setter - def auto_register_python_variables(self, enabled: bool) -> None: - warnings.warn( - _AUTO_REGISTER_PYTHON_VARIABLES_DEPRECATED, - DeprecationWarning, - stacklevel=2, - ) - self.set_python_table_lookup(bool(enabled)) - def register_object_store( self, schema: str, store: Any, host: str | None = None ) -> None: @@ -792,34 +741,29 @@ def _register_python_tables(self, tables: list[str]) -> bool: def _lookup_python_object(name: str) -> Any | None: frame = inspect.currentframe() try: - if frame is not None: - frame = frame.f_back + frame = frame.f_back if frame is not None else None lower_name = name.lower() def _match(mapping: dict[str, Any]) -> Any | None: - if not mapping: - return None - value = mapping.get(name) if value is not None: return value for key, candidate in mapping.items(): - if isinstance(key, str) and key.lower() == lower_name: - if candidate is not None: - return candidate + if ( + isinstance(key, str) + and key.lower() == lower_name + and candidate is not None + ): + return candidate return None while frame is not None: - locals_dict = frame.f_locals - match = _match(locals_dict) - if match is not None: - return match - globals_dict = frame.f_globals - match = _match(globals_dict) - if match is not None: - return match + for scope in (frame.f_locals, frame.f_globals): + match = _match(scope) + if match is not None: + return match frame = frame.f_back finally: del frame diff --git a/python/tests/test_context.py b/python/tests/test_context.py index e4ae61583..834e41b2b 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -295,7 +295,7 @@ def __init__(self) -> None: def test_sql_auto_register_arrow_table(): - ctx = SessionContext(auto_register_python_variables=True) + ctx = SessionContext(auto_register_python_objects=True) arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 result = ctx.sql( @@ -326,9 +326,7 @@ def test_sql_auto_register_multiple_tables_single_query(): ).collect() actual = pa.Table.from_batches(result) - expected = pa.Table.from_pydict( - {"customer_id": [1, 2], "order_id": [100, 200]} - ) + expected = pa.Table.from_pydict({"customer_id": [1, 2], "order_id": [100, 200]}) assert actual.equals(expected) assert ctx.table_exist("customers") @@ -337,7 +335,7 @@ def test_sql_auto_register_multiple_tables_single_query(): def test_sql_auto_register_arrow_outer_scope(): ctx = SessionContext() - ctx.auto_register_python_variables = True + ctx.set_python_table_lookup(True) arrow_table = pa.Table.from_pydict({"value": [1, 2, 3, 4]}) # noqa: F841 def run_query(): @@ -365,7 +363,7 @@ def run_query(): def test_sql_auto_register_case_insensitive_lookup(): ctx = SessionContext(auto_register_python_objects=True) - MyTable = pa.Table.from_pydict({"value": [2, 3]}) # noqa: F841 + MyTable = pa.Table.from_pydict({"value": [2, 3]}) # noqa: N806,F841 batches = ctx.sql( "SELECT SUM(value) AS total FROM mytable", @@ -377,7 +375,7 @@ def test_sql_auto_register_case_insensitive_lookup(): def test_sql_auto_register_pandas_dataframe(): pd = pytest.importorskip("pandas") - ctx = SessionContext(auto_register_python_variables=True) + ctx = SessionContext(auto_register_python_objects=True) pandas_df = pd.DataFrame({"value": [1, 2, 3, 4]}) # noqa: F841 result = ctx.sql( @@ -411,7 +409,7 @@ def test_sql_auto_register_refreshes_reassigned_dataframe(): def test_sql_auto_register_polars_dataframe(): pl = pytest.importorskip("polars") - ctx = SessionContext(auto_register_python_variables=True) + ctx = SessionContext(auto_register_python_objects=True) polars_df = pl.DataFrame({"value": [2, 4, 6]}) # noqa: F841 result = ctx.sql( @@ -421,39 +419,6 @@ def test_sql_auto_register_polars_dataframe(): assert result[0].column(0).to_pylist()[0] == 2 -def test_session_context_constructor_alias_disables_lookup(): - with pytest.deprecated_call(): - ctx = SessionContext(auto_register_python_variables=False) - - arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 - - with pytest.raises(Exception, match="not found|No table named"): - ctx.sql("SELECT * FROM arrow_table").collect() - - with pytest.deprecated_call(): - assert ctx.auto_register_python_variables is False - - -def test_session_context_property_alias_setter_enables_lookup(): - ctx = SessionContext(auto_register_python_objects=False) - arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841 - - with pytest.raises(Exception, match="not found|No table named"): - ctx.sql("SELECT COUNT(*) FROM arrow_table").collect() - - with pytest.deprecated_call(): - ctx.auto_register_python_variables = True - - result = ctx.sql( - "SELECT SUM(value) AS total FROM arrow_table", - ).collect() - - assert result[0].column(0).to_pylist()[0] == 6 - - with pytest.deprecated_call(): - assert ctx.auto_register_python_variables is True - - def test_from_pydict(ctx): # create a dataframe from Python dictionary data = {"a": [1, 2, 3], "b": [4, 5, 6]} @@ -486,7 +451,7 @@ def test_from_pandas(ctx): def test_sql_from_local_arrow_table(ctx): ctx.set_python_table_lookup(True) # Enable implicit table lookup - arrow_table = pa.Table.from_pydict({"a": [1, 2], "b": ["x", "y"]}) + arrow_table = pa.Table.from_pydict({"a": [1, 2], "b": ["x", "y"]}) # noqa: F841 result = ctx.sql("SELECT * FROM arrow_table ORDER BY a").collect() actual = pa.Table.from_batches(result) @@ -498,7 +463,7 @@ def test_sql_from_local_arrow_table(ctx): def test_sql_from_local_pandas_dataframe(ctx): ctx.set_python_table_lookup(True) # Enable implicit table lookup pd = pytest.importorskip("pandas") - pandas_df = pd.DataFrame({"a": [3, 1], "b": ["z", "y"]}) + pandas_df = pd.DataFrame({"a": [3, 1], "b": ["z", "y"]}) # noqa: F841 result = ctx.sql("SELECT * FROM pandas_df ORDER BY a").collect() actual = pa.Table.from_batches(result) @@ -510,7 +475,7 @@ def test_sql_from_local_pandas_dataframe(ctx): def test_sql_from_local_polars_dataframe(ctx): ctx.set_python_table_lookup(True) # Enable implicit table lookup pl = pytest.importorskip("polars") - polars_df = pl.DataFrame({"a": [2, 1], "b": ["beta", "alpha"]}) + polars_df = pl.DataFrame({"a": [2, 1], "b": ["beta", "alpha"]}) # noqa: F841 result = ctx.sql("SELECT * FROM polars_df ORDER BY a").collect() actual = pa.Table.from_batches(result) @@ -520,7 +485,7 @@ def test_sql_from_local_polars_dataframe(ctx): def test_sql_from_local_unsupported_object(ctx): - unsupported = object() + unsupported = object() # noqa: F841 with pytest.raises(Exception, match="table 'unsupported' not found"): ctx.sql("SELECT * FROM unsupported").collect() @@ -876,7 +841,7 @@ def test_sql_with_options_no_statements(ctx): def test_session_config_python_table_lookup_enables_auto_registration(): pd = pytest.importorskip("pandas") - ctx = SessionContext(config=SessionConfig().with_python_table_lookup(True)) + ctx = SessionContext(config=SessionConfig().with_python_table_lookup(enabled=True)) pdf = pd.DataFrame({"value": [1, 2, 3]}) assert len(pdf) == 3 From dc0687410faa8531d45260d67ae65ffb5bbd998a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 23 Sep 2025 09:57:47 +0800 Subject: [PATCH 17/20] docs: clarify automatic registration of pandas and pyarrow objects in SessionContext --- docs/source/user-guide/dataframe/index.rst | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/docs/source/user-guide/dataframe/index.rst b/docs/source/user-guide/dataframe/index.rst index ced7693f8..74e1b7586 100644 --- a/docs/source/user-guide/dataframe/index.rst +++ b/docs/source/user-guide/dataframe/index.rst @@ -233,20 +233,30 @@ Core Classes such as ``ctx.sql("SELECT * FROM pdf")`` will register a pandas or PyArrow object named ``pdf`` without calling :py:meth:`~datafusion.SessionContext.from_pandas` or - :py:meth:`~datafusion.SessionContext.from_arrow` explicitly. This requires - the corresponding library (``pandas`` for pandas objects, ``pyarrow`` for - Arrow objects) to be installed. + :py:meth:`~datafusion.SessionContext.from_arrow` explicitly. This uses + the Arrow PyCapsule Interface, so the corresponding library (``pandas`` + for pandas objects, ``pyarrow`` for Arrow objects) must be installed. .. code-block:: python import pandas as pd + import pyarrow as pa from datafusion import SessionContext ctx = SessionContext(auto_register_python_objects=True) + + # pandas dataframe - requires pandas to be installed pdf = pd.DataFrame({"value": [1, 2, 3]}) + + # or pyarrow object - requires pyarrow to be installed + arrow_table = pa.table({"value": [1, 2, 3]}) + # If automatic registration is enabled, then we can query these objects directly df = ctx.sql("SELECT SUM(value) AS total FROM pdf") - print(df.to_pandas()) # automatically registers `pdf` + # or + df = ctx.sql("SELECT SUM(value) AS total FROM arrow_table") + + # without calling ctx.from_pandas() or ctx.from_arrow() explicitly Automatic lookup is disabled by default. Enable it by passing ``auto_register_python_objects=True`` when constructing the session or by From 78c26ccc66508c547796c2839a04fa64e77b364e Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 23 Sep 2025 10:51:29 +0800 Subject: [PATCH 18/20] refactor: improve auto-registration logic for Arrow and DataFrame objects in SessionContext --- python/datafusion/context.py | 36 +++++++++++++++++++++--------------- python/tests/test_context.py | 13 ++++++++++++- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 75e4d7a09..5d8440bd4 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -797,23 +797,29 @@ def _register_python_object(self, name: str, obj: Any) -> bool: if isinstance(obj, DataFrame): self.register_view(name, obj) registered = True - elif ( - obj.__class__.__module__.startswith("polars.") - and obj.__class__.__name__ == "DataFrame" - ): - self.from_polars(obj, name=name) - registered = True - elif ( - obj.__class__.__module__.startswith("pandas.") - and obj.__class__.__name__ == "DataFrame" - ): - self.from_pandas(obj, name=name) - registered = True - elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)) or ( - hasattr(obj, "__arrow_c_stream__") or hasattr(obj, "__arrow_c_array__") - ): + elif isinstance(obj, (pa.Table, pa.RecordBatch, pa.RecordBatchReader)): self.from_arrow(obj, name=name) registered = True + else: + exports_arrow_capsule = hasattr(obj, "__arrow_c_stream__") or hasattr( + obj, "__arrow_c_array__" + ) + + if exports_arrow_capsule: + self.from_arrow(obj, name=name) + registered = True + elif ( + obj.__class__.__module__.startswith("polars.") + and obj.__class__.__name__ == "DataFrame" + ): + self.from_polars(obj, name=name) + registered = True + elif ( + obj.__class__.__module__.startswith("pandas.") + and obj.__class__.__name__ == "DataFrame" + ): + self.from_pandas(obj, name=name) + registered = True if registered: try: diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 834e41b2b..815896ce1 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -372,12 +372,23 @@ def test_sql_auto_register_case_insensitive_lookup(): assert batches[0].column(0).to_pylist()[0] == 5 -def test_sql_auto_register_pandas_dataframe(): +def test_sql_auto_register_pandas_dataframe(monkeypatch): pd = pytest.importorskip("pandas") ctx = SessionContext(auto_register_python_objects=True) pandas_df = pd.DataFrame({"value": [1, 2, 3, 4]}) # noqa: F841 + if not ( + hasattr(pandas_df, "__arrow_c_stream__") + or hasattr(pandas_df, "__arrow_c_array__") + ): + pytest.skip("pandas does not expose Arrow capsule export") + + def fail_from_pandas(*args, **kwargs): # noqa: ANN002, ANN003 + raise AssertionError("from_pandas should not be called during auto-registration") + + monkeypatch.setattr(SessionContext, "from_pandas", fail_from_pandas) + result = ctx.sql( "SELECT AVG(value) AS avg_value FROM pandas_df", ).collect() From 57d6380013e0595281171594d5553ab21335ea5a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 23 Sep 2025 10:55:02 +0800 Subject: [PATCH 19/20] fix(tests): remove unused variable warning in test_sql_auto_register_pandas_dataframe --- python/tests/test_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index 815896ce1..b8a71a540 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -376,7 +376,7 @@ def test_sql_auto_register_pandas_dataframe(monkeypatch): pd = pytest.importorskip("pandas") ctx = SessionContext(auto_register_python_objects=True) - pandas_df = pd.DataFrame({"value": [1, 2, 3, 4]}) # noqa: F841 + pandas_df = pd.DataFrame({"value": [1, 2, 3, 4]}) if not ( hasattr(pandas_df, "__arrow_c_stream__") @@ -384,7 +384,7 @@ def test_sql_auto_register_pandas_dataframe(monkeypatch): ): pytest.skip("pandas does not expose Arrow capsule export") - def fail_from_pandas(*args, **kwargs): # noqa: ANN002, ANN003 + def fail_from_pandas(*args, **kwargs): raise AssertionError("from_pandas should not be called during auto-registration") monkeypatch.setattr(SessionContext, "from_pandas", fail_from_pandas) From 1a1a5b44841dea2f47b890033a747d029b775e2b Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Tue, 23 Sep 2025 10:55:52 +0800 Subject: [PATCH 20/20] Fix Ruff errors --- python/tests/test_context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tests/test_context.py b/python/tests/test_context.py index b8a71a540..60f8fe132 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -385,7 +385,8 @@ def test_sql_auto_register_pandas_dataframe(monkeypatch): pytest.skip("pandas does not expose Arrow capsule export") def fail_from_pandas(*args, **kwargs): - raise AssertionError("from_pandas should not be called during auto-registration") + msg = "from_pandas should not be called during auto-registration" + raise AssertionError(msg) monkeypatch.setattr(SessionContext, "from_pandas", fail_from_pandas)