Skip to content

Commit 1764a57

Browse files
committed
test: add tests for case-insensitive lookup and skipping None shadowing in SQL queries
1 parent 904c1ca commit 1764a57

File tree

2 files changed

+48
-4
lines changed

2 files changed

+48
-4
lines changed

python/datafusion/context.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -784,13 +784,32 @@ def _lookup_python_object(name: str) -> Any | None:
784784
try:
785785
if frame is not None:
786786
frame = frame.f_back
787+
lower_name = name.lower()
788+
789+
def _match(mapping: dict[str, Any]) -> Any | None:
790+
if not mapping:
791+
return None
792+
793+
value = mapping.get(name)
794+
if value is not None:
795+
return value
796+
797+
for key, candidate in mapping.items():
798+
if isinstance(key, str) and key.lower() == lower_name:
799+
if candidate is not None:
800+
return candidate
801+
802+
return None
803+
787804
while frame is not None:
788805
locals_dict = frame.f_locals
789-
if name in locals_dict:
790-
return locals_dict[name]
806+
match = _match(locals_dict)
807+
if match is not None:
808+
return match
791809
globals_dict = frame.f_globals
792-
if name in globals_dict:
793-
return globals_dict[name]
810+
match = _match(globals_dict)
811+
if match is not None:
812+
return match
794813
frame = frame.f_back
795814
finally:
796815
del frame

python/tests/test_context.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,31 @@ def run_query():
337337
assert result[0].column(0).to_pylist()[0] == 4
338338

339339

340+
def test_sql_auto_register_skips_none_shadowing():
341+
ctx = SessionContext(auto_register_python_objects=True)
342+
mytable = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841
343+
344+
def run_query():
345+
mytable = None # noqa: F841
346+
return ctx.sql(
347+
"SELECT SUM(value) AS total FROM mytable",
348+
).collect()
349+
350+
batches = run_query()
351+
assert batches[0].column(0).to_pylist()[0] == 6
352+
353+
354+
def test_sql_auto_register_case_insensitive_lookup():
355+
ctx = SessionContext(auto_register_python_objects=True)
356+
MyTable = pa.Table.from_pydict({"value": [2, 3]}) # noqa: F841
357+
358+
batches = ctx.sql(
359+
"SELECT SUM(value) AS total FROM mytable",
360+
).collect()
361+
362+
assert batches[0].column(0).to_pylist()[0] == 5
363+
364+
340365
def test_sql_auto_register_pandas_dataframe():
341366
pd = pytest.importorskip("pandas")
342367

0 commit comments

Comments
 (0)