Skip to content

Commit 19f6911

Browse files
committed
feat: add normalization for missing table names extraction in SessionContext
1 parent cb5329a commit 19f6911

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

python/datafusion/context.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -743,17 +743,33 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
743743

744744
@staticmethod
745745
def _extract_missing_table_names(err: Exception) -> list[str]:
746+
def _normalize(names: list[Any]) -> list[str]:
747+
tables: list[str] = []
748+
for raw_name in names:
749+
if not raw_name:
750+
continue
751+
raw_str = str(raw_name)
752+
tables.append(raw_str.rsplit(".", 1)[-1])
753+
return tables
754+
755+
missing_tables = getattr(err, "missing_table_names", None)
756+
if missing_tables is not None:
757+
if isinstance(missing_tables, str):
758+
candidates: list[Any] = [missing_tables]
759+
else:
760+
try:
761+
candidates = list(missing_tables)
762+
except TypeError:
763+
candidates = [missing_tables]
764+
765+
return _normalize(candidates)
766+
746767
message = str(err)
747768
matches = set()
748769
for pattern in (r"table '([^']+)' not found", r"No table named '([^']+)'"):
749770
matches.update(re.findall(pattern, message))
750771

751-
tables: list[str] = []
752-
for raw_name in matches:
753-
if not raw_name:
754-
continue
755-
tables.append(raw_name.rsplit(".", 1)[-1])
756-
return tables
772+
return _normalize(list(matches))
757773

758774
def _register_python_tables(self, tables: list[str]) -> bool:
759775
registered_any = False

python/tests/test_context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,20 @@ def test_sql_missing_table_without_auto_register(ctx):
268268
assert "arrow_table" in missing_tables
269269

270270

271+
def test_extract_missing_table_names_from_attribute():
272+
class MissingTablesError(Exception):
273+
def __init__(self) -> None:
274+
super().__init__("custom error")
275+
self.missing_table_names = (
276+
"catalog.schema.arrow_table",
277+
"plain_table",
278+
)
279+
280+
err = MissingTablesError()
281+
missing_tables = SessionContext._extract_missing_table_names(err)
282+
assert missing_tables == ["arrow_table", "plain_table"]
283+
284+
271285
def test_sql_auto_register_arrow_table():
272286
ctx = SessionContext(auto_register_python_variables=True)
273287
arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841

0 commit comments

Comments
 (0)