From 3fb2d071e29bcc4055725b83a00ca490c50db9c4 Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Thu, 14 May 2026 00:35:26 -0700 Subject: [PATCH] [ET Device Support] make device support config method-based MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Device memory planning and H2D/D2H copy configuration were previously global flags on ExecutorchBackendConfig, applied uniformly across all methods in a multi-method program. This made it impossible to configure different behaviors per method — e.g., skipping H2D copies for one method while keeping them for another. This diff makes these configs method-based by: 1. Moving `enable_non_cpu_memory_planning` into MemoryPlanningPass, which already supports per-method dispatch via Dict[str, PassType] on ExecutorchBackendConfig.memory_planning_pass, follow other memory-planing-related config like `alloc_input` or `alloc_output` 2. Introducing PropagateDeviceConfig dataclass that groups `skip_h2d_for_method_inputs` and `skip_d2h_for_method_outputs`, with each field accepting either a single bool or a Dict[str, bool] for per-method overrides. A new `propagate_device_config` field on ExecutorchBackendConfig similarly accepts either a single config or Dict[str, PropagateDeviceConfig]. Differential Revision: [D101243687](https://our.internmc.facebook.com/intern/diff/D101243687/) [ghstack-poisoned] --- docs/source/compiler-memory-planning.md | 4 +-- exir/capture/_config.py | 26 +++++----------- exir/emit/test/test_emit.py | 9 +++--- exir/passes/_device_copy_ops_registry.py | 6 ++-- exir/passes/memory_planning_pass.py | 6 ++++ exir/passes/propagate_device_pass.py | 31 ++++++++++++++++++- exir/program/_program.py | 22 +++++++------ exir/tests/test_propagate_device_pass.py | 22 +++++++------ runtime/executor/test/method_meta_test.cpp | 2 +- .../models/export_program_with_device_info.py | 2 +- 10 files changed, 80 insertions(+), 50 deletions(-) diff --git a/docs/source/compiler-memory-planning.md b/docs/source/compiler-memory-planning.md index e6b6ca16283..5bebb46b34c 100644 --- a/docs/source/compiler-memory-planning.md +++ b/docs/source/compiler-memory-planning.md @@ -94,7 +94,7 @@ Users attempting to write a custom memory planning algorithm should start by loo ## Device-Aware Memory Planning -When `enable_non_cpu_memory_planning=True` is set on `ExecutorchBackendConfig`, +When `enable_non_cpu_memory_planning=True` is set on `MemoryPlanningPass`, the memory planning pass partitions tensor specs by their device type and runs the planning algorithm independently for each device. This produces separate memory buffers for each device (e.g. CPU vs. CUDA), ensuring that device memory @@ -103,7 +103,7 @@ and host memory are never mixed. ```python program = edge_program.to_executorch( exir.ExecutorchBackendConfig( - enable_non_cpu_memory_planning=True, + memory_planning_pass=exir.passes.MemoryPlanningPass(enable_non_cpu_memory_planning=True), ) ) ``` diff --git a/exir/capture/_config.py b/exir/capture/_config.py index 4ff70095041..ab2867c9d14 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -13,6 +13,7 @@ from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass +from executorch.exir.passes.propagate_device_pass import PropagateDeviceConfig from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass from executorch.exir.tracer import ExirDynamoConfig from torch.fx._compatibility import compatibility @@ -60,7 +61,12 @@ class ExecutorchBackendConfig: # A single memory planning pass can be defined for all the programs in the # EdgeProgramManager or can be defined per program. memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass() - to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False) + + # A single propagate device config can be defined for all the programs in the + # EdgeProgramManager or can be defined per program. + propagate_device_config: Union[PropagateDeviceConfig, Dict[str, PropagateDeviceConfig]] = field(default_factory=PropagateDeviceConfig) + + to_out_var_pass: PassType = field(default_factory=lambda: ToOutVarPass(ignore_to_out_var_failure=False)) dynamic_memory_planning_mode: DynamicMemoryPlanningMode = ( DynamicMemoryPlanningMode.UPPER_BOUND ) @@ -117,21 +123,3 @@ class ExecutorchBackendConfig: # Experimental: If set to true, we run a pass to reinplace ops in the graph. run_reinplace_pass: bool = False - - # When True, memory planning partitions specs by device and runs the - # algorithm independently per device, producing separate buffers for CPU - # vs. accelerator memory. Default False preserves the legacy behavior - # where all tensors are planned into CPU memory regardless of device. - enable_non_cpu_memory_planning: bool = False - - # When True, method-level input tensors that feed directly into a device - # delegate are NOT wrapped with _h2d_copy. The user must provide tensors - # already on the target device. Useful for pipelines where inputs are - # pre-staged on GPU. - skip_h2d_for_method_inputs: bool = False - - # When True, device delegate outputs that are directly method outputs - # are NOT wrapped with _d2h_copy. The method outputs stay on device. - # Useful for cross-method GPU pipelines where the next method consumes - # GPU tensors directly. - skip_d2h_for_method_outputs: bool = False diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 4bf97f60da4..4611f227772 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -2185,9 +2185,10 @@ def forward(self, x): ExecutorBackendPartitioner() ).to_executorch() - # Check that there is only one delegate because two methods are exactly the same + # Check that there are two delegates now because the + # passes might apply differently due to per-method config support. self.assertEqual( - len(edge_program_manager.executorch_program.backend_delegate_data), 1 + len(edge_program_manager.executorch_program.backend_delegate_data), 2 ) def test_delegate_deduplicate_with_different_compile_specs(self) -> None: @@ -2710,7 +2711,7 @@ def forward(self, a, b): ) lowered = edge.to_backend(DevicePartitioner()) et_prog = lowered.to_executorch( - config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True), + config=ExecutorchBackendConfig(memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True)), ) program = et_prog._emitter_output.program @@ -2741,7 +2742,7 @@ def forward(self, a, b): compile_config=EdgeCompileConfig(_check_ir_validity=False), ) et_prog = edge.to_executorch( - config=ExecutorchBackendConfig(enable_non_cpu_memory_planning=True), + config=ExecutorchBackendConfig(memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True)), ) program = et_prog._emitter_output.program diff --git a/exir/passes/_device_copy_ops_registry.py b/exir/passes/_device_copy_ops_registry.py index a62b88d4234..3dd87aa47e2 100644 --- a/exir/passes/_device_copy_ops_registry.py +++ b/exir/passes/_device_copy_ops_registry.py @@ -8,9 +8,9 @@ Registry for device copy ops used to insert explicit H2D (host-to-device) and D2H (device-to-host) data transfer operations at delegate boundaries. -These ops are inserted by PropagateDevicePass when enable_non_cpu_memory_planning -is True, making the graph functional by explicitly transferring data between -CPU and device memory. +These ops are inserted by PropagateDevicePass when memory planning is configured with +enable_non_cpu_memory_planning=True, making the graph functional by explicitly +transferring data between CPU and device memory. Follows the same registration pattern as dim_order_ops_registry.py. """ diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py index 32c343a4607..faeb1d8ca76 100644 --- a/exir/passes/memory_planning_pass.py +++ b/exir/passes/memory_planning_pass.py @@ -174,6 +174,12 @@ def __init__( self.alloc_mutable_buffers = alloc_mutable_buffers self.share_mutable_buffers = share_mutable_buffers self.alignment = alignment + + # When True, memory planning partitions specs by device and runs the + # algorithm independently per device, producing separate buffers for CPU + # vs. accelerator memory. Default False preserves the legacy behavior + # where all tensors are planned into CPU memory regardless of device. + # A dict can be used to set per-method values, keyed by method name. self.enable_non_cpu_memory_planning = enable_non_cpu_memory_planning self.state = _MemoryPlanningState() diff --git a/exir/passes/propagate_device_pass.py b/exir/passes/propagate_device_pass.py index 8c14ce52ea8..8506965dd6c 100644 --- a/exir/passes/propagate_device_pass.py +++ b/exir/passes/propagate_device_pass.py @@ -9,7 +9,9 @@ import copy import logging import operator -from typing import Optional +from typing import Optional, Union, Dict +from dataclasses import dataclass +from torch.fx._compatibility import compatibility # Import to register the et_copy ops so torch.ops.et_copy is available. import executorch.exir.passes._device_copy_ops_registry # noqa: F401 @@ -29,6 +31,33 @@ # with this key and a value encoding the device string (e.g., b"cuda:0"). TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device" +@compatibility(is_backward_compatible=False) +@dataclass +class PropagateDeviceConfig: + # When True, method-level input tensors that feed directly into a device + # delegate are NOT wrapped with _h2d_copy. The user must provide tensors + # already on the target device. Useful for pipelines where inputs are + # pre-staged on GPU. + # A dict can be used to set per-method values, keyed by method name. + skip_h2d_for_method_inputs: Union[bool, Dict[str, bool]] = False + + # When True, device delegate outputs that are directly method outputs + # are NOT wrapped with _d2h_copy. The method outputs stay on device. + # Useful for cross-method GPU pipelines where the next method consumes + # GPU tensors directly. + # A dict can be used to set per-method values, keyed by method name. + skip_d2h_for_method_outputs: Union[bool, Dict[str, bool]] = False + + def __hash__(self): + return hash((str(self.skip_h2d_for_method_inputs), str(self.skip_d2h_for_method_outputs))) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PropagateDeviceConfig): + return False + return (self.skip_h2d_for_method_inputs == other.skip_h2d_for_method_inputs and + self.skip_d2h_for_method_outputs == other.skip_d2h_for_method_outputs) + + def _parse_device_spec_value(value: bytes) -> tuple[schema.DeviceType, int]: """ diff --git a/exir/program/_program.py b/exir/program/_program.py index 89f598ef3f4..9a2a943855a 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -59,7 +59,7 @@ from executorch.exir.passes.normalize_view_copy_base_pass import ( NormalizeViewCopyBasePass, ) -from executorch.exir.passes.propagate_device_pass import PropagateDevicePass +from executorch.exir.passes.propagate_device_pass import PropagateDevicePass, PropagateDeviceConfig from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass from executorch.exir.passes.reinplace import reinplace_pass from executorch.exir.passes.remove_graph_asserts_pass import ( @@ -843,16 +843,24 @@ def edge_to_executorch_passes( Returns a list of passes to lower from edge to executorch. Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass. """ + # Handle propagate device config + propagate_device_config = config.propagate_device_config + if isinstance(propagate_device_config, dict): + device_cfg = propagate_device_config.get(name, PropagateDeviceConfig()) + else: + device_cfg = propagate_device_config + passes: List[PassType] = [ # ExecuTorch backend ops are unable to handle unbacked symints. So after # this pass, passes cannot be Interpreter-based, because it will fail if # there exists an unbacked symint operation. *config.passes, SpecPropPass(), + # By default instantiate the pass without arguments since we compare object ids later PropagateDevicePass( - skip_h2d_for_method_inputs=config.skip_h2d_for_method_inputs, - skip_d2h_for_method_outputs=config.skip_d2h_for_method_outputs, - ), + skip_h2d_for_method_inputs=device_cfg.skip_h2d_for_method_inputs, + skip_d2h_for_method_outputs=device_cfg.skip_d2h_for_method_outputs, + ) if device_cfg.skip_h2d_for_method_inputs or device_cfg.skip_d2h_for_method_outputs else PropagateDevicePass(), EdgeToBackendOpsPass(), RemoveGraphAssertsPass(), ] + pre_memory_planning_passes(config, name) @@ -1801,12 +1809,6 @@ def to_executorch( # noqa (FLAKE8) C901 ) else: memory_planning_pass = config.memory_planning_pass - # Propagate enable_non_cpu_memory_planning from the top-level config - # to the pass instance so that device-aware partitioning is applied. - if hasattr(memory_planning_pass, "enable_non_cpu_memory_planning"): - memory_planning_pass.enable_non_cpu_memory_planning = ( - config.enable_non_cpu_memory_planning - ) # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work if hasattr(memory_planning_pass, "run"): new_gm_res = memory_planning_pass.run(new_gm, new_signature) diff --git a/exir/tests/test_propagate_device_pass.py b/exir/tests/test_propagate_device_pass.py index 696c339344b..ee882953d5b 100644 --- a/exir/tests/test_propagate_device_pass.py +++ b/exir/tests/test_propagate_device_pass.py @@ -32,8 +32,10 @@ from executorch.exir.passes.propagate_device_pass import ( _get_target_device_from_compile_specs, _parse_device_spec_value, + PropagateDeviceConfig, TARGET_DEVICE_COMPILE_SPEC_KEY, ) +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.exir.schema import DeviceType from executorch.exir.tensor import TensorSpec from torch.export import export @@ -726,8 +728,8 @@ def forward(self, a, b): inputs = (torch.randn(2, 2), torch.randn(2, 2)) et_config = ExecutorchBackendConfig( emit_stacktrace=False, - skip_h2d_for_method_inputs=True, - enable_non_cpu_memory_planning=True, + propagate_device_config=PropagateDeviceConfig(skip_h2d_for_method_inputs=True), + memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True), ) for pipeline, program, gm in self._get_executorch_program( @@ -782,8 +784,8 @@ def forward(self, a, b): inputs = (torch.randn(2, 2), torch.randn(2, 2)) et_config = ExecutorchBackendConfig( emit_stacktrace=False, - skip_d2h_for_method_outputs=True, - enable_non_cpu_memory_planning=True, + propagate_device_config=PropagateDeviceConfig(skip_d2h_for_method_outputs=True), + memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True), ) for pipeline, program, gm in self._get_executorch_program( @@ -836,9 +838,11 @@ def forward(self, a, b): inputs = (torch.randn(2, 2), torch.randn(2, 2)) et_config = ExecutorchBackendConfig( emit_stacktrace=False, - skip_h2d_for_method_inputs=True, - skip_d2h_for_method_outputs=True, - enable_non_cpu_memory_planning=True, + propagate_device_config=PropagateDeviceConfig( + skip_h2d_for_method_inputs=True, + skip_d2h_for_method_outputs=True, + ), + memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True), ) for pipeline, program, gm in self._get_executorch_program( @@ -912,8 +916,8 @@ def forward(self, a, b): inputs = (torch.randn(2, 2), torch.randn(2, 2)) et_config = ExecutorchBackendConfig( emit_stacktrace=False, - skip_h2d_for_method_inputs=True, - enable_non_cpu_memory_planning=True, + propagate_device_config=PropagateDeviceConfig(skip_h2d_for_method_inputs=True), + memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True), ) for pipeline, program, gm in self._get_executorch_program( diff --git a/runtime/executor/test/method_meta_test.cpp b/runtime/executor/test/method_meta_test.cpp index 3e6e09cc8c3..23b8506ebd5 100644 --- a/runtime/executor/test/method_meta_test.cpp +++ b/runtime/executor/test/method_meta_test.cpp @@ -247,7 +247,7 @@ TEST_F(MethodMetaTest, MethodMetaBufferDeviceReturnsCudaForDeviceBuffer) { programs_["add_with_device"]->method_meta("forward"); ASSERT_EQ(method_meta.error(), Error::Ok); - // ModuleAddWithDevice exports with enable_non_cpu_memory_planning=True. + // ModuleAddWithDevice exports with memory_planning_pass=MemoryPlanningPass(enable_non_cpu_memory_planning=True). // The model delegates add(a,b) to CUDA, producing: // non_const_buffer_sizes: [0, 48] (index 0 reserved) // non_const_buffer_device: [{buffer_idx=1, device_type=CUDA, diff --git a/test/models/export_program_with_device_info.py b/test/models/export_program_with_device_info.py index 3b6af55c6e8..b7d394e2455 100644 --- a/test/models/export_program_with_device_info.py +++ b/test/models/export_program_with_device_info.py @@ -102,7 +102,7 @@ def main() -> None: et_prog = lowered.to_executorch( ExecutorchBackendConfig( # type: ignore[call-arg] emit_stacktrace=False, - enable_non_cpu_memory_planning=True, + memory_planning_pass=exir.passes.MemoryPlanningPass(enable_non_cpu_memory_planning=True), ) )