diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index d025f6f35..3e45fbed5 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -97,14 +97,34 @@ def collect_instance_variables(self, func_node: ast.FunctionDef) -> None: if self.class_name is None or self.only_function_name != "forward": return - for node in ast.walk(func_node): + class_name = self.class_name + instance_vars = self.instance_variable_names + + # Manually traverse only assignment nodes instead of walking entire tree + nodes_to_check = list(func_node.body) + while nodes_to_check: + node = nodes_to_check.pop() + # Look for assignments like: model = ClassName(...) if isinstance(node, ast.Assign): - if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name): - if node.value.func.id == self.class_name: + value = node.value + if isinstance(value, ast.Call): + func = value.func + if isinstance(func, ast.Name) and func.id == class_name: for target in node.targets: if isinstance(target, ast.Name): - self.instance_variable_names.add(target.id) + instance_vars.add(target.id) + + # Add nested statements to check + if hasattr(node, 'body'): + nodes_to_check.extend(node.body) + if hasattr(node, 'orelse'): + nodes_to_check.extend(node.orelse) + if hasattr(node, 'finalbody'): + nodes_to_check.extend(node.finalbody) + if hasattr(node, 'handlers'): + for handler in node.handlers: + nodes_to_check.extend(handler.body) def find_and_update_line_node( self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None