1919
2020from __future__ import annotations
2121
22+ import importlib
2223import inspect
2324import re
2425import warnings
26+ from functools import cache
2527from typing import TYPE_CHECKING , Any , Iterator , Protocol
2628
2729try :
5557 from datafusion .plan import ExecutionPlan , LogicalPlan
5658
5759
60+ @cache
61+ def _load_optional_module (module_name : str ) -> Any | None :
62+ """Return the module for *module_name* if it can be imported."""
63+ try :
64+ return importlib .import_module (module_name )
65+ except ModuleNotFoundError :
66+ return None
67+
68+
5869class ArrowStreamExportable (Protocol ):
5970 """Type hint for object exporting Arrow C Stream via Arrow PyCapsule Interface.
6071
@@ -105,6 +116,7 @@ def __init__(self, config_options: dict[str, str] | None = None) -> None:
105116 config_options: Configuration options.
106117 """
107118 self .config_internal = SessionConfigInternal (config_options )
119+ self ._python_table_lookup = False
108120
109121 def with_create_default_catalog_and_schema (
110122 self , enabled : bool = True
@@ -274,6 +286,11 @@ def with_parquet_pruning(self, enabled: bool = True) -> SessionConfig:
274286 self .config_internal = self .config_internal .with_parquet_pruning (enabled )
275287 return self
276288
289+ def with_python_table_lookup (self , enabled : bool = True ) -> SessionConfig :
290+ """Enable implicit table lookup for Python objects when running SQL."""
291+ self ._python_table_lookup = enabled
292+ return self
293+
277294 def set (self , key : str , value : str ) -> SessionConfig :
278295 """Set a configuration option.
279296
@@ -513,11 +530,17 @@ def __init__(
513530 ctx = SessionContext()
514531 df = ctx.read_csv("data.csv")
515532 """
516- config = config .config_internal if config is not None else None
517- runtime = runtime .config_internal if runtime is not None else None
533+ python_table_lookup = auto_register_python_variables # Use parameter as default
534+ if config is not None :
535+ python_table_lookup = config ._python_table_lookup
536+ config_internal = config .config_internal
537+ else :
538+ config_internal = None
539+
540+ runtime_internal = runtime .config_internal if runtime is not None else None
518541
519- self .ctx = SessionContextInternal (config , runtime )
520- self ._auto_register_python_variables = auto_register_python_variables
542+ self .ctx = SessionContextInternal (config_internal , runtime_internal )
543+ self ._python_table_lookup = python_table_lookup
521544
522545 def __repr__ (self ) -> str :
523546 """Print a string representation of the Session Context."""
@@ -544,17 +567,27 @@ def enable_url_table(self) -> SessionContext:
544567 klass = self .__class__
545568 obj = klass .__new__ (klass )
546569 obj .ctx = self .ctx .enable_url_table ()
547- obj ._auto_register_python_variables = self ._auto_register_python_variables
570+ obj ._python_table_lookup = self ._python_table_lookup
548571 return obj
549572
573+ def set_python_table_lookup (self , enabled : bool ) -> None :
574+ """Enable or disable implicit table lookup for Python objects."""
575+ self ._python_table_lookup = enabled
576+
577+ # Backward compatibility properties
550578 @property
551579 def auto_register_python_variables (self ) -> bool :
552580 """Toggle automatic registration of Python variables in SQL queries."""
553- return self ._auto_register_python_variables
581+ return self ._python_table_lookup
554582
555583 @auto_register_python_variables .setter
556584 def auto_register_python_variables (self , enabled : bool ) -> None :
557- self ._auto_register_python_variables = bool (enabled )
585+ self ._python_table_lookup = bool (enabled )
586+
587+ def _extract_missing_table_names (self , error : Exception ) -> set [str ]:
588+ """Extract missing table names from error (backward compatibility)."""
589+ missing_table = self ._extract_missing_table_name (error )
590+ return {missing_table } if missing_table else set ()
558591
559592 def register_object_store (
560593 self , schema : str , store : Any , host : str | None = None
@@ -620,12 +653,29 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
620653 Returns:
621654 DataFrame representation of the SQL query.
622655 """
623- options_internal = None if options is None else options .options_internal
624- return self ._sql_with_retry (
625- query ,
626- options_internal ,
627- self ._auto_register_python_variables ,
628- )
656+ attempted_missing_tables : set [str ] = set ()
657+
658+ while True :
659+ try :
660+ if options is None :
661+ result = self .ctx .sql (query )
662+ else :
663+ result = self .ctx .sql_with_options (query , options .options_internal )
664+ except Exception as exc :
665+ missing_table = self ._extract_missing_table_name (exc )
666+ if (
667+ missing_table is None
668+ or missing_table in attempted_missing_tables
669+ or not self ._python_table_lookup
670+ ):
671+ raise
672+
673+ attempted_missing_tables .add (missing_table )
674+ if not self ._register_missing_table_from_callers (missing_table ):
675+ raise
676+ continue
677+
678+ return DataFrame (result )
629679
630680 def sql_with_options (self , query : str , options : SQLOptions ) -> DataFrame :
631681 """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
@@ -642,137 +692,122 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
642692 """
643693 return self .sql (query , options )
644694
645- def _sql_with_retry (
646- self ,
647- query : str ,
648- options_internal : SQLOptionsInternal | None ,
649- allow_retry : bool ,
650- ) -> DataFrame :
651- try :
652- if options_internal is None :
653- return DataFrame (self .ctx .sql (query ))
654- return DataFrame (self .ctx .sql_with_options (query , options_internal ))
655- except Exception as exc :
656- if not allow_retry or not self ._handle_missing_table_error (exc ):
657- raise
658- return self ._sql_with_retry (query , options_internal , allow_retry )
659-
660- def _handle_missing_table_error (self , error : Exception ) -> bool :
661- missing_tables = self ._extract_missing_table_names (error )
662- if not missing_tables :
663- return False
664-
665- registered_any = False
666- attempted : set [str ] = set ()
667- for raw_name in missing_tables :
668- for candidate in self ._candidate_table_names (raw_name ):
669- if candidate in attempted :
670- continue
671- attempted .add (candidate )
672-
673- value = self ._lookup_python_variable (candidate )
674- if value is None :
675- continue
676- if self ._register_python_value (candidate , value ):
677- registered_any = True
678- break
679- return registered_any
680-
681- def _candidate_table_names (self , identifier : str ) -> Iterator [str ]:
682- cleaned = identifier .strip ().strip ('"' )
683- if not cleaned :
684- return
685-
686- seen : set [str ] = set ()
687- candidates = [cleaned ]
688- if "." in cleaned :
689- candidates .append (cleaned .rsplit ("." , 1 )[- 1 ])
690-
691- for candidate in candidates :
692- normalized = candidate .strip ()
693- if not normalized or normalized in seen :
694- continue
695- seen .add (normalized )
696- yield normalized
697-
698- def _extract_missing_table_names (self , error : Exception ) -> set [str ]:
699- names : set [str ] = set ()
700- attribute = getattr (error , "missing_table_names" , None )
701- if attribute is not None :
702- if isinstance (attribute , (list , tuple , set , frozenset )):
703- for item in attribute :
704- if item is None :
705- continue
706- for candidate in self ._candidate_table_names (str (item )):
707- names .add (candidate )
708- elif attribute is not None :
709- for candidate in self ._candidate_table_names (str (attribute )):
710- names .add (candidate )
711- if names :
712- return names
713-
695+ @staticmethod
696+ def _extract_missing_table_name (error : Exception ) -> str | None :
714697 message = str (error )
715- return {match .group (1 ) for match in _MISSING_TABLE_PATTERN .finditer (message )}
698+ patterns = (
699+ r"table '([^']+)' not found" ,
700+ r"Table not found: ['\"]?([^\s'\"]+)['\"]?" ,
701+ r"Table or CTE with name ['\"]?([^\s'\"]+)['\"]? not found" ,
702+ r"Invalid reference to table ['\"]?([^\s'\"]+)['\"]?" ,
703+ )
704+ for pattern in patterns :
705+ if match := re .search (pattern , message ):
706+ return match .group (1 )
707+ return None
716708
717- def _lookup_python_variable (self , name : str ) -> Any | None :
709+ def _register_missing_table_from_callers (self , table_name : str ) -> bool :
718710 frame = inspect .currentframe ()
719- outer = frame . f_back if frame is not None else None
720- lower_name = name . lower ()
711+ if frame is None :
712+ return False
721713
722714 try :
723- while outer is not None :
724- for mapping in (outer .f_locals , outer .f_globals ):
725- if not mapping :
726- continue
727- if name in mapping :
728- value = mapping [name ]
729- if value is not None :
730- return value
731- # allow outer scopes to provide a non-``None`` value
732- continue
733- for key , value in mapping .items ():
734- if value is None :
735- continue
736- if key == name or key .lower () == lower_name :
737- return value
738- outer = outer .f_back
715+ frame = frame .f_back
716+ if frame is None :
717+ return False
718+ frame = frame .f_back
719+ while frame is not None :
720+ if self ._register_from_namespace (table_name , frame .f_locals ):
721+ return True
722+ if self ._register_from_namespace (table_name , frame .f_globals ):
723+ return True
724+ frame = frame .f_back
739725 finally :
740- del outer
741726 del frame
742- return None
727+ return False
743728
744- def _register_python_value (self , table_name : str , value : Any ) -> bool :
745- if value is None :
729+ def _register_from_namespace (
730+ self , table_name : str , namespace : dict [str , Any ]
731+ ) -> bool :
732+ if table_name not in namespace :
746733 return False
734+ value = namespace [table_name ]
735+ return self ._register_python_value (table_name , value )
736+
737+ def _register_python_value (self , table_name : str , value : Any ) -> bool :
738+ pandas = _load_optional_module ("pandas" )
739+ polars = _load_optional_module ("polars" )
740+ polars_df = getattr (polars , "DataFrame" , None ) if polars is not None else None
741+
742+ handlers = (
743+ (isinstance (value , DataFrame ), self ._register_datafusion_dataframe ),
744+ (
745+ isinstance (value , (pa .Table , pa .RecordBatch , pa .RecordBatchReader )),
746+ self ._register_arrow_object ,
747+ ),
748+ (
749+ pandas is not None and isinstance (value , pandas .DataFrame ),
750+ self ._register_pandas_dataframe ,
751+ ),
752+ (
753+ polars_df is not None and isinstance (value , polars_df ),
754+ self ._register_polars_dataframe ,
755+ ),
756+ )
757+
758+ for matches , handler in handlers :
759+ if matches :
760+ return handler (table_name , value )
761+
762+ return False
747763
748- registered = False
749- if isinstance ( value , DataFrame ) :
764+ def _register_datafusion_dataframe ( self , table_name : str , value : DataFrame ) -> bool :
765+ try :
750766 self .register_view (table_name , value )
751- registered = True
752- elif isinstance (value , Table ):
753- self .register_table (table_name , value )
754- registered = True
755- else :
756- provider = getattr (value , "__datafusion_table_provider__" , None )
757- if callable (provider ):
758- self .register_table_provider (table_name , value )
759- registered = True
760- elif hasattr (value , "__arrow_c_stream__" ) or hasattr (
761- value , "__arrow_c_array__"
762- ):
763- self .from_arrow (value , name = table_name )
764- registered = True
765- else :
766- module_name = getattr (type (value ), "__module__" , "" ) or ""
767- class_name = getattr (type (value ), "__name__" , "" ) or ""
768- if module_name .startswith ("pandas." ) and class_name == "DataFrame" :
769- self .from_pandas (value , name = table_name )
770- registered = True
771- elif module_name .startswith ("polars" ) and class_name == "DataFrame" :
772- self .from_polars (value , name = table_name )
773- registered = True
774-
775- return registered
767+ except Exception as exc : # noqa: BLE001
768+ warnings .warn (
769+ "Failed to register DataFusion DataFrame for table "
770+ f"'{ table_name } ': { exc } " ,
771+ stacklevel = 4 ,
772+ )
773+ return False
774+ return True
775+
776+ def _register_arrow_object (self , table_name : str , value : Any ) -> bool :
777+ try :
778+ self .from_arrow (value , table_name )
779+ except Exception as exc : # noqa: BLE001
780+ warnings .warn (
781+ "Failed to register Arrow data for table "
782+ f"'{ table_name } ': { exc } " ,
783+ stacklevel = 4 ,
784+ )
785+ return False
786+ return True
787+
788+ def _register_pandas_dataframe (self , table_name : str , value : Any ) -> bool :
789+ try :
790+ self .from_pandas (value , table_name )
791+ except Exception as exc : # noqa: BLE001
792+ warnings .warn (
793+ "Failed to register pandas DataFrame for table "
794+ f"'{ table_name } ': { exc } " ,
795+ stacklevel = 4 ,
796+ )
797+ return False
798+ return True
799+
800+ def _register_polars_dataframe (self , table_name : str , value : Any ) -> bool :
801+ try :
802+ self .from_polars (value , table_name )
803+ except Exception as exc : # noqa: BLE001
804+ warnings .warn (
805+ "Failed to register polars DataFrame for table "
806+ f"'{ table_name } ': { exc } " ,
807+ stacklevel = 4 ,
808+ )
809+ return False
810+ return True
776811
777812 def create_dataframe (
778813 self ,
0 commit comments