diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 4366468d0..d86a695ab 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -79,9 +79,53 @@ def __init__( self.only_function_name = function.function_name self.module_path = module_path self.call_positions = call_positions + # Track instance variables when optimizing forward methods (PyTorch nn.Module pattern) + self.instance_variable_names: set[str] = set() if len(function.parents) == 1 and function.parents[0].type == "ClassDef": self.class_name = function.top_level_parent_name + def collect_instance_variables(self, func_node: ast.FunctionDef) -> None: + """Collect variable names that are instances of the target class. + + This handles the PyTorch nn.Module pattern where: + model = AlexNet(...) + model(input_data) # calls __call__ which invokes forward() + + When optimizing ClassName.forward, we need to track variables assigned + from ClassName(...) so we can instrument calls to those variables. + """ + if self.class_name is None or self.only_function_name != "forward": + return + + class_name = self.class_name + instance_vars = self.instance_variable_names + + # Manually traverse only assignment nodes instead of walking entire tree + nodes_to_check = list(func_node.body) + while nodes_to_check: + node = nodes_to_check.pop() + + # Look for assignments like: model = ClassName(...) 
+ if isinstance(node, ast.Assign): + value = node.value + if isinstance(value, ast.Call): + func = value.func + if isinstance(func, ast.Name) and func.id == class_name: + for target in node.targets: + if isinstance(target, ast.Name): + instance_vars.add(target.id) + + # Add nested statements to check + if hasattr(node, "body"): + nodes_to_check.extend(node.body) + if hasattr(node, "orelse"): + nodes_to_check.extend(node.orelse) + if hasattr(node, "finalbody"): + nodes_to_check.extend(node.finalbody) + if hasattr(node, "handlers"): + for handler in node.handlers: + nodes_to_check.extend(handler.body) + def find_and_update_line_node( self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None ) -> Iterable[ast.stmt] | None: @@ -122,7 +166,16 @@ def iter_ast_calls(node): codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load()) for node in iter_ast_calls(test_node): - if not node_in_call_position(node, self.call_positions): + # Check if this call is at a known position OR is an instance variable call + # for forward methods (PyTorch nn.Module pattern) + is_at_call_position = node_in_call_position(node, self.call_positions) + is_instance_call = ( + isinstance(node.func, ast.Name) + and node.func.id in self.instance_variable_names + and self.only_function_name == "forward" + ) + + if not is_at_call_position and not is_instance_call: continue call_node = node @@ -134,7 +187,8 @@ def iter_ast_calls(node): function_name = node_func.id # Check if this is the function we want to instrument - if function_name != fn_obj.function_name: + # Also match instance variable calls for forward methods + if function_name != fn_obj.function_name and function_name not in self.instance_variable_names: continue if fn_obj.is_async: @@ -325,6 +379,9 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef: if node.name.startswith("test_"): + # 
def test_pytorch_forward_method_instrumentation() -> None:
    """Check that a call through a module instance is instrumented for ``forward``.

    The scenario under test is:
        model = MyModule(...)
        model(input_data)  # __call__ dispatches to forward()

    The only call position handed to the instrumenter points at the class
    constructor call, not at ``model(input_data)``; the instance call is
    still expected to end up wrapped.
    """
    code = """
class MockModule:
    def __init__(self, num_classes=10):
        self.num_classes = num_classes

    def forward(self, x):
        return x * 2

def test_module():
    model = MockModule(num_classes=10)
    input_data = 5
    result = model(input_data)
    assert result == 10
"""
    tmp_dir = Path(tempfile.gettempdir())
    code_path = tmp_dir / "mock_module.py"
    test_path = tmp_dir / "test_mock_module.py"

    try:
        # The same source serves as both the module under optimization and the test file.
        for destination in (code_path, test_path):
            destination.write_text(code)

        target_function = FunctionToOptimize(
            function_name="forward",
            parents=[FunctionParent("MockModule", "ClassDef")],
            file_path=code_path,
            starting_line=6,
            ending_line=7,
            is_async=False,
        )

        # Line 10, column 12 of the source above is the ``MockModule(...)`` constructor call.
        constructor_position = CodePosition(line_no=10, col_no=12)

        success, new_test = inject_profiling_into_existing_test(
            test_path,
            [constructor_position],
            target_function,
            test_path.parent,
            mode=TestingMode.PERFORMANCE,
        )

        assert success
        assert new_test is not None

        # Main expectation: the instance itself is passed to codeflash_wrap as the callable.
        wrap_message = (
            "Expected model(input_data) to be wrapped as codeflash_wrap(model, ..., input_data), "
            f"but got:\n{new_test}"
        )
        assert "codeflash_wrap(model," in new_test, wrap_message

        # The instrumented source must also mention the qualified name of the target method.
        qualified_name_message = (
            f"Expected 'MockModule.forward' to appear in the instrumented code, but got:\n{new_test}"
        )
        assert "MockModule.forward" in new_test, qualified_name_message

    finally:
        for leftover in (code_path, test_path):
            leftover.unlink(missing_ok=True)