Skip to content

Conversation

@aseembits93
Copy link
Contributor

Summary

  • Fix instrumentation of PyTorch nn.Module forward method when called via instance (e.g., model(input_data))
  • Add special handling for the pattern: model = ClassName(...); model(input_data) where model(input_data) internally calls forward()
  • Resolves "Ignoring test case that passed but had no runtime" error when optimizing forward methods

Problem

When running codeflash --function AlexNet.forward, tests with this pattern weren't being instrumented:

model = AlexNet(num_classes=10)
result = model(input_data)  # calls __call__ which invokes forward()

The instrumentation was looking for direct calls to forward or AlexNet, but model(input_data) matched neither.

Solution

  1. Added collect_instance_variables() to track variables assigned from class instantiations
  2. Modified find_and_update_line_node() to wrap calls to instance variables when optimizing forward methods
  3. Added test case specifically for this PyTorch pattern

Test plan

  • Added test_pytorch_forward_method_instrumentation test case
  • All existing instrumentation tests pass (19/19)
  • Verified with actual codeflash command - runtime is now properly measured

🤖 Generated with Claude Code

When optimizing a `forward` method on a class (e.g., AlexNet.forward),
the test pattern `model = AlexNet(...); model(input_data)` wasn't being
instrumented because the call `model(input_data)` didn't match the
expected function name "forward".

This fix adds special handling for the PyTorch nn.Module pattern:
- Collect variable names assigned from class instantiations
- Also wrap calls to those instance variables when optimizing `forward`

Fixes the "Ignoring test case that passed but had no runtime" error
when running codeflash on PyTorch model forward methods.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
if node.name.startswith("test_"):
# Collect instance variables for forward method instrumentation (PyTorch pattern)
self.collect_instance_variables(node)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit (non-blocking): instance_variable_names is accumulated across all test functions without being cleared. If a file has multiple test functions, variable names collected from test_a will persist when processing test_b. This could cause false-positive instrumentation if a variable name from one test happens to be called in another.

Consider clearing the set at the start of each test function:

Suggested change
self.collect_instance_variables(node)
self.instance_variable_names.clear()
self.collect_instance_variables(node)

@claude
Copy link

claude bot commented Feb 6, 2026

PR Review Summary

Prek Checks

✅ All prek checks pass (ruff check, ruff format). No fixes needed.

Mypy

⚠️ All mypy errors in the changed files are pre-existing (122 errors across both files, none introduced by this PR). The new code follows the same patterns as the existing codebase.

Code Review

The implementation is clean and well-scoped. One minor observation:

  • Non-blocking: instance_variable_names set is not cleared between test functions in the same InjectPerfOnly visitor instance. Variables collected from one test function could leak into the processing of another. See inline comment for a suggested one-line fix.

No critical bugs, security vulnerabilities, or breaking API changes found.

Test Coverage

File Main PR Change
codeflash/code_utils/instrument_existing_tests.py 55% (423 stmts, 192 miss) 56% (437 stmts, 192 miss) +1%
  • New code coverage: All 14 new statements are covered by the new test_pytorch_forward_method_instrumentation test ✅
  • No coverage regression
  • The new test verifies the instrumentation output correctly wraps model(input_data) as codeflash_wrap(model, ...) with the qualified name MockModule.forward

Codeflash Optimization PRs

Checked 7 open optimization PRs targeting main. None have all CI checks passing — no PRs merged.


Last updated: 2026-02-06T

@codeflash-ai
Copy link
Contributor

codeflash-ai bot commented Feb 6, 2026

⚡️ Codeflash found optimizations for this PR

📄 769% (7.69x) speedup for InjectPerfOnly.collect_instance_variables in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 1.30 milliseconds 150 microseconds (best of 15 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch fix/pytorch-forward-method-instrumentation).

Static Badge

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant