Skip to content

Commit 776cc1c

Browse files
Fix: Allow init to be walked to track its dependencies
1 parent 5211a57 commit 776cc1c

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

sqlmesh/utils/metaprogramming.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
352352
walk(base, base.__qualname__, is_metadata)
353353

354354
for k, v in obj.__dict__.items():
355-
if k.startswith("__"):
355+
# skip dunder methods bar __init__ as it might contain user defined logic with cross class references
356+
if k.startswith("__") and k != "__init__":
356357
continue
357358

358359
# Traverse methods in a class to find global references
@@ -362,10 +363,14 @@ def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None:
362363
if callable(v):
363364
# Walk the method if it's part of the object, else it's a global function and we just store it
364365
if v.__qualname__.startswith(obj.__qualname__):
365-
for k, v in func_globals(v).items():
366-
walk(v, k, is_metadata)
367-
else:
368-
walk(v, v.__name__, is_metadata)
366+
try:
367+
for k, v in func_globals(v).items():
368+
walk(v, k, is_metadata)
369+
except (OSError, TypeError):
370+
# __init__ may come from built-ins or wrapped callables
371+
pass
372+
else:
373+
walk(v, k, is_metadata)
369374
elif callable(obj):
370375
for k, v in func_globals(obj).items():
371376
walk(v, k, is_metadata)

tests/utils/test_metaprogramming.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,48 @@ def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None:
460460
assert serialized_env == expected_env
461461

462462

463+
class ReferencedClass:
464+
def __init__(self, value: int):
465+
self.value = value
466+
467+
def get_value(self) -> int:
468+
return self.value
469+
470+
471+
class ClassThatReferencesAnother:
472+
def __init__(self, x: int):
473+
self.helper = ReferencedClass(x * 2)
474+
475+
def compute(self) -> int:
476+
return self.helper.get_value() + 10
477+
478+
479+
def function_using_class_with_reference(y: int) -> int:
480+
obj = ClassThatReferencesAnother(y)
481+
return obj.compute()
482+
483+
484+
def test_serialize_env_with_class_referencing_another_class() -> None:
485+
# firstly we can confirm that func_globals picks up the reference
486+
init_globals = func_globals(ClassThatReferencesAnother.__init__)
487+
assert "ReferencedClass" in init_globals
488+
489+
path = Path("tests/utils")
490+
env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {}
491+
492+
# build ajd serialize environment for the function that uses the class
493+
build_env(function_using_class_with_reference, env=env, name="test_func", path=path)
494+
serialized_env = serialize_env(env, path=path)
495+
496+
# both classes should be in the serialized environment
497+
assert "ClassThatReferencesAnother" in serialized_env
498+
assert "ReferencedClass" in serialized_env
499+
500+
prepared_env = prepare_env(serialized_env)
501+
result = eval("test_func(33)", prepared_env)
502+
assert result == 76
503+
504+
463505
def test_dict_sort_basic_types():
464506
"""Test dict_sort with basic Python types."""
465507
# Test basic types that should use standard repr

0 commit comments

Comments
 (0)