-
Notifications
You must be signed in to change notification settings - Fork 21
fix: instrument PyTorch nn.Module forward method calls via instance #1418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
When optimizing a `forward` method on a class (e.g., AlexNet.forward), the test pattern `model = AlexNet(...); model(input_data)` wasn't being instrumented because the call `model(input_data)` didn't match the expected function name "forward". This fix adds special handling for the PyTorch nn.Module pattern: - Collect variable names assigned from class instantiations - Also wrap calls to those instance variables when optimizing `forward` Fixes the "Ignoring test case that passed but had no runtime" error when running codeflash on PyTorch model forward methods. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
| def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef: | ||
| if node.name.startswith("test_"): | ||
| # Collect instance variables for forward method instrumentation (PyTorch pattern) | ||
| self.collect_instance_variables(node) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit (non-blocking): instance_variable_names is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from test_a will persist when processing test_b. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.
Consider clearing the set at the start of each test function:
| self.collect_instance_variables(node) | |
| self.instance_variable_names.clear() | |
| self.collect_instance_variables(node) |
PR Review SummaryPrek Checks✅ All prek checks pass (ruff check, ruff format). No fixes needed. MypyCode ReviewThe implementation is clean and well-scoped. One minor observation:
No critical bugs, security vulnerabilities, or breaking API changes found. Test Coverage
Codeflash Optimization PRsChecked 7 open optimization PRs targeting Last updated: 2026-02-06T |
⚡️ Codeflash found optimizations for this PR📄 769% (7.69x) speedup for
|
Summary
nn.Moduleforward method when called via instance (e.g.,model(input_data))model = ClassName(...); model(input_data)wheremodel(input_data)internally callsforward()forwardmethodsProblem
When running
codeflash --function AlexNet.forward, tests with this pattern weren't being instrumented:The instrumentation was looking for direct calls to
forwardorAlexNet, butmodel(input_data)matched neither.Solution
collect_instance_variables()to track variables assigned from class instantiationsfind_and_update_line_node()to wrap calls to instance variables when optimizingforwardmethodsTest plan
test_pytorch_forward_method_instrumentationtest case🤖 Generated with Claude Code