feat: Add Inductor backend config templates #688

Open
JewelRoam wants to merge 1 commit into PaddlePaddle:develop from JewelRoam:bench

Conversation

@JewelRoam
Collaborator

Overview

This PR introduces a flexible configuration system for the PyTorch Inductor
backend, allowing users to select predefined config templates that each set a
group of torch._inductor.config overrides. This extends PyTorch's official
"mode" concept while maintaining full compatibility with the existing
test_compiler.py framework.

Motivation

Previously, the InductorBackend accepted only basic config parameters through
individual inductor_config dictionary entries. Users could not easily enable
common combinations of Inductor options such as:

  • CUTLASS-based GEMM kernels - For optimal GEMM performance on modern GPUs
  • CUDA Graphs - To reduce kernel launch overhead for small batch inference
  • Model freezing - To inline weights as constants for deployment optimization
  • TMA (Tensor Memory Accelerator) - For H100+ GPUs with hardware acceleration

This PR addresses these limitations by introducing config templates - pre-defined,
well-tested combinations of torch._inductor.config options that users can select
by name.

Changes Summary

1. Inductor Backend Configuration Templates

File: graph_net_bench/torch/backend/inductor_backend.py

New Features

  • _INDUCTOR_CONFIG_TEMPLATES dictionary with 8 predefined templates
  • _TEMPLATE_TO_COMPILE_MODE mapping for templates that imply compile modes
  • _set_nested_attr() utility function for setting nested config attributes

Supported Templates

| Template | Description | Compile Mode |
| --- | --- | --- |
| triton | Default Triton backend with cpp_wrapper disabled | default |
| cpp_wrapper | Use C++ wrapper for generated kernels | default |
| cutlass | Enable max-autotune for CUTLASS GEMM kernels | default |
| aten | Enable ATen fallback for debugging | default |
| cudagraphs | Enable CUDA Graphs for reduced overhead | reduce-overhead |
| max_autotune | Comprehensive autotuning across backends | max-autotune |
| freezing | Enable model freezing (inline weights) | default |
| tma | Enable TMA persistent matmul kernels | default |

TMA Graceful Fallback

The TMA template has built-in graceful fallback behavior:

  • On H100+ (Hopper, CC >= 9.0): Enables TMA persistent kernels
  • On A100 or other GPUs: Enables non-TMA persistent kernels as fallback
  • Configuration is always accepted regardless of underlying hardware
  • No runtime error occurs on GPUs without TMA support

This ensures the template works universally while still leveraging TMA
benefits when available.
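
The fallback decision reduces to a check on the device's compute capability, as returned by `torch.cuda.get_device_capability()`. A hedged sketch of the selection logic (the config key name here is illustrative and may differ from the PR's actual keys):

```python
def tma_overrides(compute_capability):
    """Pick TMA-related overrides from a (major, minor) capability tuple.

    The override key below is illustrative, not the PR's exact key set.
    """
    if compute_capability >= (9, 0):
        # Hopper (H100) and newer: hardware TMA is available.
        return {"triton.enable_persistent_tma_matmul": True}
    # Older GPUs (e.g. A100, CC 8.0): persistent kernels without TMA.
    return {"triton.enable_persistent_tma_matmul": False}
```

Because the choice is made at config-build time, the template is accepted on any GPU and never raises at runtime.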

Enhanced Configuration Interface

```python
config = {
    "template": "max_autotune",  # Select predefined template
    "mode": "reduce-overhead",   # Override compile mode
    "freezing": True,            # Top-level override
    "inductor_config": {         # Arbitrary config overrides
        "triton.cudagraphs": False,  # Nested config (dotted paths)
        "triton.enable_persistent_tma_matmul": True,  # Deeply nested
    },
}
```

Configuration Priority (highest to lowest)

  1. inductor_config - Explicit user-specified overrides
  2. freezing - Top-level convenience flag
  3. template - Predefined template defaults
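
Applied in reverse order (lowest priority first), this is a plain dict merge where later updates win. A sketch, assuming a hypothetical `templates` table mapping template names to their default override dicts:

```python
def build_overrides(config, templates):
    """Merge override sources; later updates win (highest priority last).

    `templates` maps template names to default override dicts
    (a stand-in for the PR's _INDUCTOR_CONFIG_TEMPLATES).
    """
    overrides = {}
    template = config.get("template")
    if template is not None:
        if template not in templates:
            raise ValueError(f"unknown template: {template!r}")
        overrides.update(templates[template])            # lowest priority
    if "freezing" in config:
        overrides["freezing"] = config["freezing"]       # convenience flag
    overrides.update(config.get("inductor_config", {}))  # highest priority
    return overrides
```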

2. CUDA Graphs Compatibility Fix

File: graph_net_bench/torch/test_compiler.py

Problem

When CUDA Graphs is enabled (via triton.cudagraphs or mode="reduce-overhead"):

  • PyTorch records output tensor pointers to CUDA Graph buffers
  • Subsequent model calls overwrite the same buffer locations
  • Test framework runs eager model, then compiled model sequentially
  • Accessing compiled output after eager run triggers:
    RuntimeError: Error: accessing tensor output of CUDAGraphs
    that has been overwritten by a subsequent run.
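
The failure mode can be mimicked without a GPU: a replayed graph writes every result into one fixed buffer, so a reference to a previous output is unsafe unless it was copied first. A toy analogy (FakeCudaGraphModel is an illustrative stand-in, not PyTorch API):

```python
class FakeCudaGraphModel:
    """Mimics CUDA Graphs replay: every call writes into the same buffer."""
    def __init__(self):
        self._buffer = [0.0]

    def __call__(self, x):
        self._buffer[0] = x * 2.0
        return self._buffer  # caller receives a view of the shared buffer

model = FakeCudaGraphModel()
first = model(1.0)
snapshot = list(first)  # analogous to tensor.clone()
model(3.0)
# Without the copy, `first` now shows the overwritten value:
assert first[0] == 6.0 and snapshot[0] == 2.0
```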
    

Solution

Clone model outputs immediately after model invocation:

```python
# Clone outputs to prevent CUDA Graphs buffer overwrite issues.
if isinstance(outs, torch.Tensor):
    outs = outs.clone()
elif isinstance(outs, tuple):
    outs = tuple(t.clone() if isinstance(t, torch.Tensor) else t for t in outs)
```

Impact

  • Fixed in: test_compiler.py
  • Not affected: eval_backend_perf.py and eval_backend_diff.py
    (they use torch.save/torch.load which creates independent copies)

3. Comprehensive Test Suite

File: test/inductor_backend_test.py (new file, 323 lines)

Test Structure

```python
# Three test classes covering different aspects:

class TestInductorBackendTemplates(unittest.TestCase):
    # 10 tests for template configuration validation

class TestInductorBackendIntegration(unittest.TestCase):
    # 2 tests for overall backend integration

class TestInductorConfigValidation(unittest.TestCase):
    # 15 tests for PyTorch config validation
    # 1 test for TMA fallback on non-TMA GPUs
```

Test Coverage

| Category | Tests | Coverage |
| --- | --- | --- |
| Template validation | 10 | All 8 templates produce correct overrides and modes |
| Integration | 2 | Backend initialization and override priority |
| Config validation | 15 | All config keys exist in torch._inductor.config |
| TMA fallback | 1 | Graceful fallback on non-TMA GPUs (e.g., A100) |
| **Total** | **28** | Full test coverage |

Validation Results

```
$ python -m unittest test.inductor_backend_test -v

test_all_templates_exist ... ok
test_aten_template ... ok
test_cpp_wrapper_template ... ok
test_cudagraphs_template ... ok
test_cutlass_template ... ok
test_empty_config ... ok
test_freezing_override ... ok
test_freezing_template ... ok
test_inductor_config_override ... ok
test_invalid_template_raises_error ... ok
test_max_autotune_template ... ok
test_mode_override ... ok
test_tma_fallback_on_non_tma_gpu ... ok
test_tma_template ... ok
test_template_to_mode_mapping ... ok
test_triton_template ... ok
# ... (all 28 tests pass)

Ran 28 tests in 0.002s
OK
```

Usage Examples

Basic Template Usage

```bash
# Use CUTLASS-based GEMM kernels
python -m graph_net_bench.torch.test_compiler \
    --compiler inductor \
    --model-path samples/torchvision/alexnet \
    --config "eyJ0ZW1wbGF0ZSI6ICJjdXRsYXNzIn0=" \
    --trials 5 --warmup 3
```

Template with Custom Mode

```bash
# Use max-autotune template but disable CUDA Graphs
python -c "
import json, base64
config = {
    'template': 'max_autotune',
    'mode': 'max-autotune-no-cudagraphs'
}
print(base64.b64encode(json.dumps(config).encode()).decode())
"

python -m graph_net_bench.torch.test_compiler \
    --compiler inductor \
    --model-path samples/torchvision/alexnet \
    --config "$(python -c ...)" \
    --trials 5 --warmup 3
```

Combined Configuration

```bash
# Combine template with additional options
python -c "
import json, base64
config = {
    'template': 'max_autotune',
    'freezing': True,
    'inductor_config': {
        'coordinate_descent_tuning': False
    }
}
print(base64.b64encode(json.dumps(config).encode()).decode())
"
```

All Available Templates

| Template Name | Description | Base64 Config |
| --- | --- | --- |
| triton | Default Triton backend | `eyJ0ZW1wbGF0ZSI6ICJ0cml0b24ifQ==` |
| cpp_wrapper | C++ wrapper for kernels | `eyJ0ZW1wbGF0ZSI6ICJjcHBfd3JhcHBlciJ9` |
| cutlass | CUTLASS GEMM kernels | `eyJ0ZW1wbGF0ZSI6ICJjdXRsYXNzIn0=` |
| aten | ATen fallback for debugging | `eyJ0ZW1wbGF0ZSI6ICJhdGVuIn0=` |
| cudagraphs | CUDA Graphs for reduced overhead | `eyJ0ZW1wbGF0ZSI6ICJjdWRhZ3JhcGhzIn0=` |
| max_autotune | Maximum autotuning | `eyJ0ZW1wbGF0ZSI6ICJtYXhfYXV0b3R1bmUifQ==` |
| freezing | Model freezing | `eyJ0ZW1wbGF0ZSI6ICJmcmVlemluZyJ9` |
| tma | Tensor Memory Accelerator | `eyJ0ZW1wbGF0ZSI6ICJ0bWEifQ==` |
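
The base64 strings above are just UTF-8 JSON. A small helper to produce or inspect them (`encode_config`/`decode_config` are illustrative names, not part of the PR's API):

```python
import base64
import json

def encode_config(config):
    """Encode a config dict into the base64 string passed to --config."""
    return base64.b64encode(json.dumps(config).encode()).decode()

def decode_config(blob):
    """Decode a --config string back into a dict."""
    return json.loads(base64.b64decode(blob))

print(encode_config({"template": "triton"}))
# → eyJ0ZW1wbGF0ZSI6ICJ0cml0b24ifQ==
```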

Testing

Manual Testing Results

All 8 templates have been manually tested with actual model compilation:

| Template | Result |
| --- | --- |
| triton | Works correctly |
| cpp_wrapper | Works correctly |
| cutlass | AUTOTUNE output visible in logs |
| aten | Works correctly |
| cudagraphs | Fixed - no more CUDA Graphs errors |
| max_autotune | AUTOTUNE output visible |
| freezing | Works correctly |
| tma | Graceful fallback on A100 |

Automated Testing

```
$ python -m unittest test.inductor_backend_test
Ran 28 tests in 0.002s
OK
```

Breaking Changes

None. This is a purely additive feature with no changes to existing behavior
when the template parameter is not specified.

Migration Guide

For users wanting to use the new template system:

  1. Basic Usage: Add --config with base64-encoded template
  2. Combined Usage: Mix templates with additional config options
  3. Custom Config: Use inductor_config for any custom settings

No code changes required - this is fully backward compatible.

Documentation References

All configuration keys have been verified against PyTorch 2.7.1 source code.

Hardware Requirements

| Feature | Minimum Requirement |
| --- | --- |
| TMA (Tensor Memory Accelerator) | NVIDIA H100+ (Compute Capability >= 9.0) |
| CUDA Graphs | Any CUDA-capable GPU |
| CUTLASS GEMM | Separate CUTLASS backend installation required |
| Triton | CUDA-capable GPU with compute capability >= 7.0 |

Performance Considerations

  • CUDA Graphs: Best for small batch sizes, may increase memory usage
  • CUTLASS: Significantly faster GEMM for large tensors
  • TMA: Memory bandwidth optimization, requires H100+ hardware
  • Freezing: Reduces memory footprint; weights cannot be updated after freezing
  • Max Autotune: Longer compilation time but optimal performance

Checklist

  • All config keys verified against PyTorch source code
  • All templates produce correct overrides
  • Unit tests added (28 tests)
  • Unit tests pass (28/28 OK)
  • Manual testing completed for all 8 templates
  • CUDA Graphs compatibility issue fixed
  • Documentation updated with inline comments
  • Usage examples provided
  • Backward compatibility maintained
  • Code follows PyTorch naming conventions

Related Issues

  • Fixes CUDA Graphs output buffer overwrite issue in test framework
  • Provides graceful TMA fallback for non-TMA GPUs (e.g., A100)

## Overview

This PR introduces a flexible configuration system for PyTorch Inductor backend
with 8 predefined config templates, CUDA Graphs compatibility fix,
and comprehensive unit tests (28 tests total).

## Changes

- Inductor backend with 8 config templates (triton, cpp_wrapper, cutlass,
  aten, cudagraphs, max_autotune, freezing, tma)
- CUDA Graphs output buffer overwrite fix in test_compiler.py
- 28 unit tests in test/inductor_backend_test.py

## Testing

- All config keys verified against PyTorch 2.7.1 source code
- All templates tested with actual model compilation
- Unit tests pass: 28/28 OK
- TMA config gracefully falls back on non-TMA GPUs (A100)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@paddle-bot

paddle-bot bot commented Apr 15, 2026

Thanks for your contribution!

@JewelRoam JewelRoam changed the title feat: Add Inductor backend config templates and comprehensive test suite feat: Add Inductor backend config templates Apr 15, 2026