Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/compiler-memory-planning.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ Users attempting to write a custom memory planning algorithm should start by loo

## Device-Aware Memory Planning

When `enable_non_cpu_memory_planning=True` is set on `ExecutorchBackendConfig`,
When `enable_non_cpu_memory_planning=True` is set on `MemoryPlanningPass`,
the memory planning pass partitions tensor specs by their device type and runs
the planning algorithm independently for each device. This produces separate
memory buffers for each device (e.g. CPU vs. CUDA), ensuring that device memory
Expand All @@ -103,7 +103,7 @@ and host memory are never mixed.
```python
program = edge_program.to_executorch(
exir.ExecutorchBackendConfig(
enable_non_cpu_memory_planning=True,
memory_planning_pass=exir.passes.MemoryPlanningPass(enable_non_cpu_memory_planning=True),
)
)
```
Expand Down
26 changes: 7 additions & 19 deletions exir/capture/_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -13,6 +13,7 @@
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
from executorch.exir.pass_manager import PassType
from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass
from executorch.exir.passes.propagate_device_pass import PropagateDeviceConfig
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.tracer import ExirDynamoConfig
from torch.fx._compatibility import compatibility
Expand Down Expand Up @@ -60,7 +61,12 @@
# A single memory planning pass can be defined for all the programs in the
# EdgeProgramManager or can be defined per program.
memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass()
to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False)

# A single propagate device config can be defined for all the programs in the
# EdgeProgramManager or can be defined per program.
propagate_device_config: Union[PropagateDeviceConfig, Dict[str, PropagateDeviceConfig]] = field(default_factory=PropagateDeviceConfig)

to_out_var_pass: PassType = field(default_factory=lambda: ToOutVarPass(ignore_to_out_var_failure=False))
dynamic_memory_planning_mode: DynamicMemoryPlanningMode = (
DynamicMemoryPlanningMode.UPPER_BOUND
)
Expand Down Expand Up @@ -117,21 +123,3 @@

# Experimental: If set to true, we run a pass to reinplace ops in the graph.
run_reinplace_pass: bool = False

# When True, memory planning partitions specs by device and runs the
# algorithm independently per device, producing separate buffers for CPU
# vs. accelerator memory. Default False preserves the legacy behavior
# where all tensors are planned into CPU memory regardless of device.
enable_non_cpu_memory_planning: bool = False

# When True, method-level input tensors that feed directly into a device
# delegate are NOT wrapped with _h2d_copy. The user must provide tensors
# already on the target device. Useful for pipelines where inputs are
# pre-staged on GPU.
skip_h2d_for_method_inputs: bool = False

# When True, device delegate outputs that are directly method outputs
# are NOT wrapped with _d2h_copy. The method outputs stay on device.
# Useful for cross-method GPU pipelines where the next method consumes
# GPU tensors directly.
skip_d2h_for_method_outputs: bool = False
9 changes: 5 additions & 4 deletions exir/emit/test/test_emit.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -2185,9 +2185,10 @@
ExecutorBackendPartitioner()
).to_executorch()

# Check that there is only one delegate because two methods are exactly the same
# Check that there are two delegates now because the
# passes might apply differently due to per-method config support.
self.assertEqual(
len(edge_program_manager.executorch_program.backend_delegate_data), 1
len(edge_program_manager.executorch_program.backend_delegate_data), 2
)

def test_delegate_deduplicate_with_different_compile_specs(self) -> None:
Expand Down Expand Up @@ -2710,7 +2711,7 @@
)
lowered = edge.to_backend(DevicePartitioner())
et_prog = lowered.to_executorch(
config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
config=ExecutorchBackendConfig(memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True)),
)
program = et_prog._emitter_output.program

Expand Down Expand Up @@ -2741,7 +2742,7 @@
compile_config=EdgeCompileConfig(_check_ir_validity=False),
)
et_prog = edge.to_executorch(
config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True),
config=ExecutorchBackendConfig(memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True)),
)
program = et_prog._emitter_output.program

Expand Down
6 changes: 3 additions & 3 deletions exir/passes/_device_copy_ops_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
Registry for device copy ops used to insert explicit H2D (host-to-device)
and D2H (device-to-host) data transfer operations at delegate boundaries.

These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning
is True, making the graph functional by explicitly transferring data between
CPU and device memory.
These ops are inserted by PropagateDevicePass when memory planning is configured with
enable_non_cpu_memory_planning=True, making the graph functional by explicitly
transferring data between CPU and device memory.

Follows the same registration pattern as dim_order_ops_registry.py.
"""
Expand Down
6 changes: 6 additions & 0 deletions exir/passes/memory_planning_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def __init__(
self.alloc_mutable_buffers = alloc_mutable_buffers
self.share_mutable_buffers = share_mutable_buffers
self.alignment = alignment

# When True, memory planning partitions specs by device and runs the
# algorithm independently per device, producing separate buffers for CPU
# vs. accelerator memory. Default False preserves the legacy behavior
# where all tensors are planned into CPU memory regardless of device.
# A dict can be used to set per-method values, keyed by method name.
self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning
self.state = _MemoryPlanningState()

Expand Down
31 changes: 30 additions & 1 deletion exir/passes/propagate_device_pass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -9,7 +9,9 @@
import copy
import logging
import operator
from typing import Optional
from typing import Optional, Union, Dict
from dataclasses import dataclass
from torch.fx._compatibility import compatibility

# Import to register the et_copy ops so torch.ops.et_copy is available.
import executorch.exir.passes._device_copy_ops_registry # noqa: F401
Expand All @@ -29,6 +31,33 @@
# with this key and a value encoding the device string (e.g., b"cuda:0").
TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"

@compatibility(is_backward_compatible=False)
@dataclass
class PropagateDeviceConfig:
    """Options for skipping device copy ops at method boundaries.

    Each option is either a single bool or a dict of bools keyed by method
    name. Instances are hashable; two instances that compare equal always
    produce the same hash, including when their per-method dicts were built
    in a different key order.
    """

    # When True, method-level input tensors that feed directly into a device
    # delegate are NOT wrapped with _h2d_copy. The user must provide tensors
    # already on the target device. Useful for pipelines where inputs are
    # pre-staged on GPU.
    # A dict can be used to set per-method values, keyed by method name.
    skip_h2d_for_method_inputs: Union[bool, Dict[str, bool]] = False

    # When True, device delegate outputs that are directly method outputs
    # are NOT wrapped with _d2h_copy. The method outputs stay on device.
    # Useful for cross-method GPU pipelines where the next method consumes
    # GPU tensors directly.
    # A dict can be used to set per-method values, keyed by method name.
    skip_d2h_for_method_outputs: Union[bool, Dict[str, bool]] = False

    @staticmethod
    def _hashable(value):
        """Return an order-independent hashable stand-in for a field value."""
        # Dicts are unhashable, and their str() depends on insertion order,
        # so equal dicts could otherwise hash differently.
        if isinstance(value, dict):
            return frozenset(value.items())
        return value

    def __hash__(self) -> int:
        """Hash both fields in a way that is consistent with __eq__."""
        return hash(
            (
                self._hashable(self.skip_h2d_for_method_inputs),
                self._hashable(self.skip_d2h_for_method_outputs),
            )
        )

    def __eq__(self, other: object) -> bool:
        """Compare field by field; defer to the other operand for foreign types."""
        if not isinstance(other, PropagateDeviceConfig):
            # `==` still evaluates to False for unrelated types via the
            # default fallback, while letting the other side have a say.
            return NotImplemented
        return (
            self.skip_h2d_for_method_inputs == other.skip_h2d_for_method_inputs
            and self.skip_d2h_for_method_outputs == other.skip_d2h_for_method_outputs
        )



def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]:
"""
Expand Down
22 changes: 12 additions & 10 deletions exir/program/_program.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
Expand Down Expand Up @@ -59,7 +59,7 @@
from executorch.exir.passes.normalize_view_copy_base_pass import (
NormalizeViewCopyBasePass,
)
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass
from executorch.exir.passes.propagate_device_pass import PropagateDevicePass, PropagateDeviceConfig
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
from executorch.exir.passes.reinplace import reinplace_pass
from executorch.exir.passes.remove_graph_asserts_pass import (
Expand Down Expand Up @@ -843,16 +843,24 @@
Returns a list of passes to lower from edge to executorch.
Get the pre-memory-planning passes based on the method name; if the pass is not in the dict, use the default pass.
"""
# Handle propagate device config
propagate_device_config = config.propagate_device_config
if isinstance(propagate_device_config, dict):
device_cfg = propagate_device_config.get(name, PropagateDeviceConfig())
else:
device_cfg = propagate_device_config

passes: List[PassType] = [
# ExecuTorch backend ops are unable to handle unbacked symints. So after
# this pass, passes cannot be Interpreter-based, because it will fail if
# there exists an unbacked symint operation.
*config.passes,
SpecPropPass(),
# By default instantiate the pass without arguments since we compare object ids later
PropagateDevicePass(
skip_h2d_for_method_inputs=config.skip_h2d_for_method_inputs,
skip_d2h_for_method_outputs=config.skip_d2h_for_method_outputs,
),
skip_h2d_for_method_inputs=device_cfg.skip_h2d_for_method_inputs,
skip_d2h_for_method_outputs=device_cfg.skip_d2h_for_method_outputs,
) if device_cfg.skip_h2d_for_method_inputs or device_cfg.skip_d2h_for_method_outputs else PropagateDevicePass(),
EdgeToBackendOpsPass(),
RemoveGraphAssertsPass(),
] + pre_memory_planning_passes(config, name)
Expand Down Expand Up @@ -1801,12 +1809,6 @@
)
else:
memory_planning_pass = config.memory_planning_pass
# Propagate enable_non_cpu_memory_planning from the top-level config
# to the pass instance so that device-aware partitioning is applied.
if hasattr(memory_planning_pass, "enable_non_cpu_memory_planning"):
memory_planning_pass.enable_non_cpu_memory_planning = (
config.enable_non_cpu_memory_planning
)
# TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work
if hasattr(memory_planning_pass, "run"):
new_gm_res = memory_planning_pass.run(new_gm, new_signature)
Expand Down
22 changes: 13 additions & 9 deletions exir/tests/test_propagate_device_pass.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -32,8 +32,10 @@
from executorch.exir.passes.propagate_device_pass import (
_get_target_device_from_compile_specs,
_parse_device_spec_value,
PropagateDeviceConfig,
TARGET_DEVICE_COMPILE_SPEC_KEY,
)
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.schema import DeviceType
from executorch.exir.tensor import TensorSpec
from torch.export import export
Expand Down Expand Up @@ -726,8 +728,8 @@
inputs = (torch.randn(2, 2), torch.randn(2, 2))
et_config = ExecutorchBackendConfig(
emit_stacktrace=False,
skip_h2d_for_method_inputs=True,
enable_non_cpu_memory_planning=True,
propagate_device_config=PropagateDeviceConfig(skip_h2d_for_method_inputs=True),
memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True),
)

for pipeline, program, gm in self._get_executorch_program(
Expand Down Expand Up @@ -782,8 +784,8 @@
inputs = (torch.randn(2, 2), torch.randn(2, 2))
et_config = ExecutorchBackendConfig(
emit_stacktrace=False,
skip_d2h_for_method_outputs=True,
enable_non_cpu_memory_planning=True,
propagate_device_config=PropagateDeviceConfig(skip_d2h_for_method_outputs=True),
memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True),
)

for pipeline, program, gm in self._get_executorch_program(
Expand Down Expand Up @@ -836,9 +838,11 @@
inputs = (torch.randn(2, 2), torch.randn(2, 2))
et_config = ExecutorchBackendConfig(
emit_stacktrace=False,
skip_h2d_for_method_inputs=True,
skip_d2h_for_method_outputs=True,
enable_non_cpu_memory_planning=True,
propagate_device_config=PropagateDeviceConfig(
skip_h2d_for_method_inputs=True,
skip_d2h_for_method_outputs=True,
),
memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True),
)

for pipeline, program, gm in self._get_executorch_program(
Expand Down Expand Up @@ -912,8 +916,8 @@
inputs = (torch.randn(2, 2), torch.randn(2, 2))
et_config = ExecutorchBackendConfig(
emit_stacktrace=False,
skip_h2d_for_method_inputs=True,
enable_non_cpu_memory_planning=True,
propagate_device_config=PropagateDeviceConfig(skip_h2d_for_method_inputs=True),
memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True),
)

for pipeline, program, gm in self._get_executorch_program(
Expand Down
2 changes: 1 addition & 1 deletion runtime/executor/test/method_meta_test.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -247,7 +247,7 @@
programs_["add_with_device"]->method_meta("forward");
ASSERT_EQ(method_meta.error(), Error::Ok);

// ModuleAddWithDevice exports with enable_non_cpu_memory_planning=True.
// ModuleAddWithDevice exports with memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True).
// The model delegates add(a,b) to CUDA, producing:
// non_const_buffer_sizes: [0, 48] (index 0 reserved)
// non_const_buffer_device: [{buffer_idx=1, device_type=CUDA,
Expand Down
2 changes: 1 addition & 1 deletion test/models/export_program_with_device_info.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand Down Expand Up @@ -102,7 +102,7 @@
et_prog = lowered.to_executorch(
ExecutorchBackendConfig( # type: ignore[call-arg]
emit_stacktrace=False,
enable_non_cpu_memory_planning=True,
memory_planning_pass=exir.passes.MemoryPlanningPass(enable_non_cpu_memory_planning=True),

Check failure on line 105 in test/models/export_program_with_device_info.py

View workflow job for this annotation

GitHub Actions / lintrunner

FLAKE8 F821

undefined name 'exir' See https://www.flake8rules.com/rules/F821.html.

Check failure on line 105 in test/models/export_program_with_device_info.py

View workflow job for this annotation

GitHub Actions / lintrunner-mypy

MYPY name-defined

Name "exir" is not defined To disable, use ` # type: ignore[name-defined]`
)
)

Expand Down
Loading