253 changes: 253 additions & 0 deletions backends/cuda/passes/weight_offload_pass.py
@@ -0,0 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Weight offloading pass for the CUDA backend.

EXPERIMENTAL -- NOT YET WIRED.

This module is a DESIGN DOCUMENT. ``_apply_weight_offload`` has no
implementation body (its body is ``...``); the ``executorch_weight_offload::probe``
custom op is registered with torch but is not preserved through inductor
lowering (see the "Probe op preservation" open item below); and
``CudaPartitioner`` does NOT yet expose a ``weight_offload`` kwarg that
would invoke this pass. Nothing in this PR calls the function or the
custom op.

Open item -- probe op preservation:
The pass's design intent is to insert one
``executorch_weight_offload::probe(w)`` call before every consumer of
every parameter placeholder, so the runtime can intercept each
parameter read via an AOTI c-shim. The current registration
(``custom_op(..., mutates_args=())`` plus an identity fake) is NOT
sufficient to keep the op alive through inductor's CSE + fusion
passes -- a side-effect-free identity op is a textbook DCE target,
and there is no test in this PR that asserts the lowered AOTI
wrapper actually emits an ``aoti_torch_cuda_probe`` call per
``(consumer, weight)`` pair. The implementation PR must:
1. Give probe explicit non-elidable semantics that survive
inductor lowering AND don't clash with torch.export's
parameter-output validation (parameters are read-only by
convention, so ``mutates_args={"w"}`` won't work directly).
2. Land a regression test that lowers a tiny two-linear model
and asserts the wrapper.cpp contains the expected
``aoti_torch_cuda_probe`` calls in the expected places.
Without both, weight offload silently degrades to eager loading at
runtime even when "enabled" -- which is exactly the surprise this
feature exists to prevent.

Open item -- payload transport:
The pass needs to run from a custom pass inside
``AotiBackend.preprocess`` -- the ``ExportedProgram`` partitioner
contract forbids mutating the program from ``partition()`` (see
``exir/backend/partitioner.py:83``), so neither ``CudaPartitioner``
nor any other partitioner can run the rewrite directly. ``preprocess``
has two candidate channels for the resulting payload (schedule,
floor, pin_fqns, version), both accessible from there and neither
yet wired:
(a) Serialize into the partition's ``processed_bytes`` (currently
``b""`` for AOTI; see ``backends/aoti/aoti_backend.py``).
(b) Attach a per-method ``NamedDataStore`` entry (where
``AotiBackend`` already writes ``_so_blob`` and
``_weights_blob``).
Pick one when wiring. The payload-key constants below are channel-
agnostic.

Open item -- schedule / cursor order:
The runtime cursor in ``Session::serve`` hard-fails on a mismatch
against the recorded schedule (see ``register_schedule`` in
``weight_offload.h``). That contract requires the execution-order
FQN list this pass records to match the order the lowered AOTI
wrapper actually invokes ``aoti_torch_cuda_probe`` in. The pass
observes the graph BEFORE inductor's custom passes / decompositions
/ lowering run (``backends/aoti/aoti_backend.py:206``), all of which
can reorder or duplicate parameter reads. Two options for the
implementation PR: (a) regenerate the schedule from the post-lowering
wrapper order; or (b) extend probe with an explicit
``probe_id`` / FQN argument so each call self-identifies and the
runtime needs no cursor at all. (b) is the more robust choice --
it removes the entire class of "graph order drifted from wrapper
order" silent failures, at the cost of a wider op signature.

The runtime half lives in ``backends/cuda/runtime/weight_offload/``,
which is also marked EXPERIMENTAL -- NOT YET WIRED.
"""

import torch
from torch.library import custom_op, register_fake
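
Reviewer note on the "schedule / cursor order" open item in the module docstring: the difference between the cursor contract and option (b) can be sketched with two toy classes. This is a minimal illustration in plain Python, not the runtime's code; ``CursorSession``, ``KeyedSession``, and the FQNs below are hypothetical names, and the real ``Session::serve`` is C++ and is not shown in this PR.

```python
class CursorSession:
    """Toy model of the cursor contract: probes must arrive in schedule order."""

    def __init__(self, schedule):
        self.schedule = list(schedule)
        self.cursor = 0

    def serve(self, observed_fqn):
        # Hard-fail as soon as the observed probe differs from the recorded order.
        if self.cursor >= len(self.schedule):
            raise RuntimeError("probe fired past the end of the schedule")
        expected = self.schedule[self.cursor]
        if observed_fqn != expected:
            raise RuntimeError(
                f"schedule mismatch at {self.cursor}: "
                f"expected {expected!r}, got {observed_fqn!r}"
            )
        self.cursor += 1
        return observed_fqn


class KeyedSession:
    """Toy model of option (b): every probe call names the weight it wants."""

    def __init__(self, known_fqns):
        self.known_fqns = set(known_fqns)

    def serve(self, fqn):
        # No cursor: call order is irrelevant, only membership is checked.
        if fqn not in self.known_fqns:
            raise KeyError(fqn)
        return fqn


# Order recorded by the pass before lowering (hypothetical FQNs).
recorded_schedule = ["layer0.weight", "layer1.weight"]
# Order the lowered wrapper happens to use (hypothetical reordering).
wrapper_order = ["layer1.weight", "layer0.weight"]

keyed = KeyedSession(recorded_schedule)
served = [keyed.serve(fqn) for fqn in wrapper_order]
```

With the same reordered input, ``CursorSession(recorded_schedule).serve(wrapper_order[0])`` raises, while the keyed variant serves both weights. That is the trade described in the docstring: a wider op signature in exchange for removing order drift as a failure mode.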


_OP_NAMESPACE = "executorch_weight_offload"
_OP_QUALNAME = f"{_OP_NAMESPACE}::probe"


@custom_op(_OP_QUALNAME, mutates_args=())
def probe(w: torch.Tensor) -> torch.Tensor:
    """Identity passthrough in eager mode. The CUDA backend replaces it via a c-shim at AOTI compile time.

Inserted by ``_apply_weight_offload`` before every consumer of every
parameter (or buffer) placeholder. The CUDA runtime's c-shim
(``aoti_torch_cuda_probe``) intercepts each call at runtime and serves
bytes through the bounded GPU pool.

Signature is deliberately minimal — no FQN or schedule-index argument.
The runtime resolves which weight is being probed by looking up the
input tensor's ``data_ptr()`` in the ``ProbeRegistry`` populated by
``Session::bind_placeholder_constants`` at backend init.

Notes:
- The current ``mutates_args=()`` is insufficient: an identity op
with no side effect is a textbook DCE target for inductor.
``mutates_args={"w"}`` clashes with torch.export's
parameter-output validation (parameters are read-only by
convention). The implementation PR must find a third option;
see the "Probe op preservation" open item in the module
docstring above.
- Weight offloading is mutually exclusive with the CUDA backend's
``enable_cuda_graph_for_method`` option: CUDA-graph replay bypasses
AOTI's ``run()``, so probe ops never fire. The runtime hard-fails
at ``init`` if both are set for the same method.
    """
    return w


@register_fake(_OP_QUALNAME)
def _probe_fake(w: torch.Tensor) -> torch.Tensor:
    # Fresh fake tensor so inductor doesn't decide to inline the op away.
    return torch.empty_like(w)


PROBE_OP_TARGET = torch.ops.executorch_weight_offload.probe.default
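
Reviewer note: the "one probe per ``(consumer, weight)`` pair" rule from the ``probe`` docstring can be illustrated on a toy graph. The sketch below deliberately avoids ``torch.fx``; ``Node`` and ``insert_probes`` are hypothetical stand-ins, and the real pass would operate on the ``ExportedProgram`` graph instead of a flat list.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str  # "placeholder", "probe", or "call" (toy categories, not fx opcodes)
    args: list = field(default_factory=list)


def insert_probes(nodes, param_names):
    """Return a new node list with one probe per (consumer, weight) pair."""
    rewritten = []
    counter = 0
    for node in nodes:
        if node.op != "call":
            rewritten.append(node)
            continue
        new_args = []
        for arg in node.args:
            if arg in param_names:
                # A fresh probe for every read, even if the weight was probed before.
                probe_name = f"probe_{counter}"
                counter += 1
                rewritten.append(Node(probe_name, "probe", [arg]))
                new_args.append(probe_name)
            else:
                new_args.append(arg)
        rewritten.append(Node(node.name, node.op, new_args))
    return rewritten


# Two consumers share the same weight "w" (hypothetical tied-weight model).
graph = [
    Node("w", "placeholder"),
    Node("x", "placeholder"),
    Node("lin1", "call", ["x", "w"]),
    Node("lin2", "call", ["lin1", "w"]),
]
rewritten = insert_probes(graph, param_names={"w"})
probe_nodes = [n for n in rewritten if n.op == "probe"]
```

Here ``w`` has two consumers, so two probe nodes are emitted, each placed immediately before the consumer that reads it. That placement is what would let a runtime reload a weight evicted between the two uses.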


# Payload field names. INTERNAL design intent for the partition-payload
# (or NamedDataStore -- see the "payload transport" open item in the
# module docstring) that ``CudaBackend.preprocess`` would write and
# ``cuda_backend.cpp::init`` would parse once wired. Names are
# namespaced by method so prefill and decode each get their own payload
# in the same .pte.
PAYLOAD_KEY_VERSION = "version"
PAYLOAD_KEY_METHOD_NAME = "method_name"
PAYLOAD_KEY_SCHEDULE = "schedule"
PAYLOAD_KEY_FLOOR = "floor_bytes"
PAYLOAD_KEY_PIN_FQNS = "pin_fqns"

# Schema version for the emitted offload payload. Bumped whenever the
# shape of any field above changes (e.g. switching the floor from
# uint64 bytes to a struct with prefetch headroom, or switching the
# budget wire option from ``weight_offload_budget_mb`` to a bytes-typed
# field once ``BackendOptions`` grows int64 support). The runtime
# hard-fails at ``CudaBackend::init`` if the version is missing or
# unknown, naming the expected range — version drift surfaces loudly at
# load instead of silently mis-parsing a payload.
SCHEMA_VERSION = 1
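
Reviewer note: the version contract described in the comment above can be sketched as an encode/decode pair. The JSON encoding here is purely an assumption for illustration (the PR leaves the wire format and the transport channel open); ``encode_payload``, ``decode_payload``, and ``SUPPORTED_VERSIONS`` are hypothetical names. Only the key strings and ``SCHEMA_VERSION = 1`` come from this module.

```python
import json

SCHEMA_VERSION = 1
# Assumption: every version from 1 up to the current one is accepted.
SUPPORTED_VERSIONS = range(1, SCHEMA_VERSION + 1)


def encode_payload(method_name, schedule, floor_bytes, pin_fqns=None):
    """Serialize the offload payload to bytes (JSON is an illustrative choice)."""
    payload = {
        "version": SCHEMA_VERSION,
        "method_name": method_name,
        "schedule": list(schedule),
        "floor_bytes": int(floor_bytes),
        "pin_fqns": list(pin_fqns or []),
    }
    return json.dumps(payload, sort_keys=True).encode("utf-8")


def decode_payload(blob):
    """Parse a payload, failing loudly on a missing or unknown version."""
    payload = json.loads(blob.decode("utf-8"))
    version = payload.get("version")
    if version not in SUPPORTED_VERSIONS:
        raise ValueError(
            f"unsupported offload payload version {version!r}; expected "
            f"{SUPPORTED_VERSIONS.start}..{SUPPORTED_VERSIONS.stop - 1}"
        )
    return payload


blob = encode_payload("decode", ["layer0.weight"], 4096, pin_fqns=["embed.weight"])
decoded = decode_payload(blob)
```

Checking the version before touching any other field is the point of the sketch: a payload with a missing or future version is rejected with the expected range in the message instead of being partially parsed.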


def _apply_weight_offload(
    exported_program,
    *,
    method_name: str,
    pin_fqns: list[str] | None = None,
) -> dict:
    """In-place graph rewrite + offload payload computation.

INTERNAL -- the leading underscore is the Python signal. The only
intended caller is the CUDA backend's ``preprocess`` path once wired
(the partitioner contract forbids mutating the program from
``partition()``; see the "payload transport" open item in the module
docstring). ``method_name`` is sourced from the compile specs that
``CudaPartitioner`` (see ``backends/cuda/cuda_partitioner.py``) emits,
so prefill and decode get distinct payloads instead of colliding.
``method_name`` is REQUIRED (no default) precisely so a direct
caller importing this function for a multi-method model cannot
silently collide all methods on ``"forward"``.

Inserts ``probe(w)`` in front of every consumer of every parameter (or
buffer) placeholder, rewriting the consumer's arg to read the probe's
output. One probe call per ``(consumer, weight)`` pair — not per
weight — so the runtime can re-load a weight that was evicted between
two uses inside the same forward pass.

Pinned FQNs still get probes inserted (so the runtime serves them
through the same path), but they are EXCLUDED from the schedule and
EXCLUDED from the floor calculation. See ``Session::Config::pin_fqns``
and ``Session::register_schedule`` in
``backends/cuda/runtime/weight_offload/weight_offload.h``.

The pass is the single authoritative source for the pin set. The
runtime has NO pin-set option; it parses the list out of the
partition payload and passes it to the Session unchanged. Pin set
affects floor correctness (the floor is computed assuming pinned
FQNs do not stream), so the runtime cannot override it.

AOTI constant-folding contract:
The pass operates on parameter placeholders in the ExportedProgram.
AOTI knobs that fold parameters out of the container at compile
time (so they no longer appear in ``get_constant_name(idx)``)
break offload: the pass cannot insert probes for parameters it
cannot see, and the folded constants would be loaded eagerly
through the normal blob path at runtime — silently defeating
offload and reintroducing the OOM this feature exists to prevent.

Exports that enable weight offload must NOT have AOTI fold
parameter constants. The exact ``torch._inductor.config`` knob
and its required value are to be verified in the implementation PR.
The pass hard-fails at export if it detects folded parameters
(placeholder count below the expected catalog count derived
from ``exported_program.state_dict``); the runtime hard-fails
again at ``Session::bind_placeholder_constants`` if any catalog
FQN is missing a probe binding — defense in depth against the
two halves drifting.

Metadata transport:
The returned ``dict`` is the offload payload that the
implementation PR will route to ``cuda_backend.cpp::init`` via the
AOTI ``preprocess`` path (see the "payload transport" open item
in the module docstring for the design constraint and the two
candidate channels: ``processed_bytes`` vs. ``NamedDataStore``).

Args:
exported_program: an ``ExportedProgram`` produced by
``torch.export.export``. Mutated in place: probe nodes are
inserted, consumer args are rewritten.
method_name: the method this pass is being applied to. Returned
verbatim in the payload so the runtime can validate which
method the bytes belong to.
pin_fqns: FQNs to mark as always-resident. Optional. The list
is propagated verbatim into the payload AND used by this pass
to exclude those FQNs from the schedule and the floor
calculation. Pinning an FQN that does not appear as a parameter
placeholder is a hard error.

Returns: a ``dict`` with the keys defined at module scope (all
internal payload, not opt-in signals):

- ``"version"``: ``int`` schema version (currently ``1``). Runtime
hard-fails on unknown version.
- ``"method_name"``: ``str``. Echoed for runtime validation.
- ``"schedule"``: ``list[str]`` of NON-PINNED parameter FQNs in
execution order. Drives the runtime cursor + prefetch (or is
obviated entirely if the implementation PR picks option (b) of
the "Schedule / cursor order" open item -- a probe-id arg).
Pinned FQNs do not appear here; their probes take a separate
fast-path in ``Session::serve`` that does not touch the cursor.
- ``"floor_bytes"``: ``int`` -- minimum GPU byte budget for the
streaming portion of the working set: the maximum, over consecutive
kernel pairs ``(K_i, K_{i+1})``, of ``bytes(K_i) + bytes(K_{i+1})``,
plus the largest single weight size, computed over the schedule
above (i.e. excluding pinned weights). The runtime asserts that
``(weight_offload_budget_mb << 20) - pinned_bytes`` covers
this; below-floor budgets hard-fail at init with the required
minimum spelled out.
- ``"pin_fqns"``: ``list[str]`` of FQNs the runtime keeps
resident. Empty if ``pin_fqns`` is unset.

The opt-in signal is intended to be a separate
``CompileSpec("weight_offload", b"1")`` emitted by ``CudaPartitioner``
when wired -- the enable signal lives in exactly one place, the
compile spec, rather than being duplicated across compile spec +
payload. Neither the compile spec nor the partitioner kwarg exists
in this PR; see the EXPERIMENTAL banner at the top of this module.

Not called by users. Not called by anything in this PR.
    """
    ...
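
Reviewer note: the ``floor_bytes`` formula and the budget check in the docstring above can be worked through in a few lines. This is a sketch under stated assumptions, not the pass's implementation: ``compute_floor_bytes`` and ``check_budget`` are hypothetical helpers, each kernel is modelled as a list of the non-pinned weight sizes it reads, and a schedule with a single kernel is assumed to use that kernel's bytes as the pair peak.

```python
def compute_floor_bytes(kernel_weight_sizes):
    """Peak bytes over consecutive kernel pairs, plus the largest single weight."""
    per_kernel = [sum(sizes) for sizes in kernel_weight_sizes]
    if not per_kernel:
        return 0
    max_single = max(
        (size for sizes in kernel_weight_sizes for size in sizes), default=0
    )
    if len(per_kernel) == 1:
        # Assumption: with no pair to compare, the lone kernel sets the peak.
        pair_peak = per_kernel[0]
    else:
        pair_peak = max(a + b for a, b in zip(per_kernel, per_kernel[1:]))
    return pair_peak + max_single


def check_budget(weight_offload_budget_mb, pinned_bytes, floor_bytes):
    """Return the streaming budget in bytes, or raise if it is below the floor."""
    available = (weight_offload_budget_mb << 20) - pinned_bytes
    if available < floor_bytes:
        raise ValueError(
            f"streaming budget {available} B is below the floor; "
            f"need at least {floor_bytes + pinned_bytes} B in total"
        )
    return available


# Three kernels reading weights of these byte sizes (made-up numbers).
floor = compute_floor_bytes([[4, 4], [8], [2, 2]])
```

For the made-up input, per-kernel totals are 8, 8, and 4 bytes, the worst consecutive pair is 16, and the largest single weight is 8, so the floor is 24 bytes. A budget of 1 MB with 1 KiB pinned passes ``check_budget`` against that floor; a budget of 0 MB does not.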
10 changes: 10 additions & 0 deletions backends/cuda/runtime/cuda_delegate_handle.h
@@ -168,6 +168,16 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {

// CUDA graph state (warmup, capture, replay, static buffers)
CudaGraphState cuda_graph_state;

// Weight offloading: the per-handle ``unique_ptr<weight_offload::Session>``
// field lands with the implementation PR alongside ``weight_offload.cpp``,
// which provides the out-of-line ``Session::~Session()`` definition that
// ``unique_ptr<Session>``'s implicit destructor needs. Adding the field
// here in an API-surface-only PR would force every TU that includes this
// header into an unresolved-symbol link error against
// ``Session::~Session()``. See
// ``backends/cuda/runtime/weight_offload/weight_offload.h`` for the
// ownership model that explains why the Session lives per-handle.
};

} // namespace cuda