From 1a0ca1062b3f2ddc5d724f725da457e91668e297 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Thu, 21 May 2026 10:14:17 -0700 Subject: [PATCH] Weight offloading design surface (CUDA backend) Design-only PR for CUDA-backend weight offloading: weights live in CPU memory, the runtime streams only the currently-needed ones to GPU through a capped cudaMemPool. Headers and docstrings only -- no implementation bodies, no caller, no wiring on the partitioner or runtime side. All four design files are marked ``EXPERIMENTAL -- NOT YET WIRED``. Public knobs (``CudaPartitioner(weight_offload=True, ...)`` and the ``weight_offload_budget_mb`` runtime spec) are intentionally NOT exposed in this PR; they ship with the implementation. Four open items block wiring and are documented inline: * Probe op preservation -- an identity custom op with ``mutates_args=()`` is a DCE target through inductor; the implementation PR must give probe non-elidable semantics that don't trip torch.export's parameter-output validation, plus a test that asserts the lowered AOTI wrapper actually emits the probe calls. * AOTI blob layout -- ``WeightCatalog::build`` needs per-constant offsets and dtype/shape. AOTI doesn't expose either today; implementation PR must either land upstream shims or serialize the metadata into the offload payload at export time. * Payload transport channel -- the pass has to run from ``AotiBackend.preprocess`` (the partitioner contract forbids mutating the ExportedProgram from ``partition()``); the implementation PR picks between ``processed_bytes`` and a per-method ``NamedDataStore`` entry. * Schedule / cursor order -- the runtime cursor hard-fails on a mismatch against the recorded schedule, but the pass observes parameter order before inductor lowering reorders / duplicates reads. Implementation PR either regenerates the schedule from the post-lowering wrapper or extends probe with a self-identifying ``probe_id`` / FQN arg so no cursor is needed. 
Read order: backends/cuda/passes/weight_offload_pass.py -- export half backends/cuda/runtime/weight_offload/weight_offload.h -- runtime backends/cuda/runtime/weight_offload/probe_op.h -- c-shim backends/cuda/runtime/weight_offload/prefetcher.h -- copy stream See: https://github.com/pytorch/executorch/issues/19709 --- backends/cuda/passes/weight_offload_pass.py | 253 ++++++ backends/cuda/runtime/cuda_delegate_handle.h | 10 + .../cuda/runtime/weight_offload/prefetcher.h | 183 ++++ .../cuda/runtime/weight_offload/probe_op.h | 60 ++ .../runtime/weight_offload/weight_offload.h | 811 ++++++++++++++++++ 5 files changed, 1317 insertions(+) create mode 100644 backends/cuda/passes/weight_offload_pass.py create mode 100644 backends/cuda/runtime/weight_offload/prefetcher.h create mode 100644 backends/cuda/runtime/weight_offload/probe_op.h create mode 100644 backends/cuda/runtime/weight_offload/weight_offload.h diff --git a/backends/cuda/passes/weight_offload_pass.py b/backends/cuda/passes/weight_offload_pass.py new file mode 100644 index 00000000000..3b7478e23a9 --- /dev/null +++ b/backends/cuda/passes/weight_offload_pass.py @@ -0,0 +1,253 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Weight offloading pass for the CUDA backend. + +EXPERIMENTAL -- NOT YET WIRED. + +This module is a DESIGN DOCUMENT. ``_apply_weight_offload`` has no +implementation body (its body is ``...``); the ``executorch_weight_offload::probe`` +custom op is registered with torch but is not preserved through inductor +lowering (see the "Probe op preservation" open item below); and +``CudaPartitioner`` does NOT yet expose a ``weight_offload`` kwarg that +would invoke this pass. Nothing in this PR calls the function or the +custom op. 
+ +Open item -- probe op preservation: + The pass's design intent is to insert one + ``executorch_weight_offload::probe(w)`` call before every consumer of + every parameter placeholder, so the runtime can intercept each + parameter read via an AOTI c-shim. The current registration + (``custom_op(..., mutates_args=())`` plus an identity fake) is NOT + sufficient to keep the op alive through inductor's CSE + fusion + passes -- a side-effect-free identity op is a textbook DCE target, + and there is no test in this PR that asserts the lowered AOTI + wrapper actually emits an ``aoti_torch_cuda_probe`` call per + ``(consumer, weight)`` pair. The implementation PR must: + 1. Give probe explicit non-elidable semantics that survive + inductor lowering AND don't clash with torch.export's + parameter-output validation (parameters are read-only by + convention, so ``mutates_args={"w"}`` won't work directly). + 2. Land a regression test that lowers a tiny two-linear model + and asserts the wrapper.cpp contains the expected + ``aoti_torch_cuda_probe`` calls in the expected places. + Without both, weight offload silently degrades to eager loading at + runtime even when "enabled" -- which is exactly the surprise this + feature exists to prevent. + +Open item -- payload transport: + The pass needs to run from a custom pass inside + ``AotiBackend.preprocess`` -- the ``ExportedProgram`` partitioner + contract forbids mutating the program from ``partition()`` (see + ``exir/backend/partitioner.py:83``), so neither ``CudaPartitioner`` + nor any other partitioner can run the rewrite directly. ``preprocess`` + has two candidate channels for the resulting payload (schedule, + floor, pin_fqns, version), both accessible from there and neither + yet wired: + (a) Serialize into the partition's ``processed_bytes`` (currently + ``b""`` for AOTI; see ``backends/aoti/aoti_backend.py``). 
+ (b) Attach a per-method ``NamedDataStore`` entry (where + ``AotiBackend`` already writes ``_so_blob`` and + ``_weights_blob``). + Pick one when wiring. The payload-key constants below are channel- + agnostic. + +Open item -- schedule / cursor order: + The runtime cursor in ``Session::serve`` hard-fails on a mismatch + against the recorded schedule (see ``register_schedule`` in + ``weight_offload.h``). That contract requires the execution-order + FQN list this pass records to match the order the lowered AOTI + wrapper actually invokes ``aoti_torch_cuda_probe`` in. The pass + observes the graph BEFORE inductor's custom passes / decompositions + / lowering run (``backends/aoti/aoti_backend.py:206``), all of which + can reorder or duplicate parameter reads. Two options for the + implementation PR: (a) regenerate the schedule from the post-lowering + wrapper order; or (b) extend probe with an explicit + ``probe_id`` / FQN argument so each call self-identifies and the + runtime needs no cursor at all. (b) is the more robust choice -- + it removes the entire class of "graph order drifted from wrapper + order" silent failures, at the cost of a wider op signature. + +The runtime half lives in ``backends/cuda/runtime/weight_offload/``, +which is also marked EXPERIMENTAL -- NOT YET WIRED. +""" + +import torch +from torch.library import custom_op, register_fake + + +_OP_NAMESPACE = "executorch_weight_offload" +_OP_QUALNAME = f"{_OP_NAMESPACE}::probe" + + +@custom_op(_OP_QUALNAME, mutates_args=()) +def probe(w: torch.Tensor) -> torch.Tensor: + """Identity passthrough in eager. CUDA backend replaces via c-shim at AOTI compile time. + + Inserted by ``_apply_weight_offload`` before every consumer of every + parameter (or buffer) placeholder. The CUDA runtime's c-shim + (``aoti_torch_cuda_probe``) intercepts each call at runtime and serves + bytes through the bounded GPU pool. + + Signature is deliberately minimal — no FQN or schedule-index argument.
+ The runtime resolves which weight is being probed by looking up the + input tensor's ``data_ptr()`` in the ``ProbeRegistry`` populated by + ``Session::bind_placeholder_constants`` at backend init. + + Notes: + - The current ``mutates_args=()`` is insufficient: an identity op + with no side effect is a textbook DCE target for inductor. + ``mutates_args={"w"}`` clashes with torch.export's + parameter-output validation (parameters are read-only by + convention). The implementation PR must find a third option; + see the "Probe op preservation" open item in the module + docstring above. + - Weight offloading is mutually exclusive with the CUDA backend's + ``enable_cuda_graph_for_method`` option: CUDA-graph Replay bypasses + AOTI's ``run()``, so probe ops never fire. The runtime hard-fails + at ``init`` if both are set for the same method. + """ + return w + + +@register_fake(_OP_QUALNAME) +def _probe_fake(w: torch.Tensor) -> torch.Tensor: + # Fresh fake tensor so inductor doesn't decide to inline the op away. + return torch.empty_like(w) + + +PROBE_OP_TARGET = torch.ops.executorch_weight_offload.probe.default + + +# Payload field names. INTERNAL design intent for the partition-payload +# (or NamedDataStore -- see the "payload transport" open item in the +# module docstring) that ``CudaBackend.preprocess`` would write and +# ``cuda_backend.cpp::init`` would parse once wired. Names are +# namespaced by method so prefill and decode each get their own payload +# in the same .pte. +PAYLOAD_KEY_VERSION = "version" +PAYLOAD_KEY_METHOD_NAME = "method_name" +PAYLOAD_KEY_SCHEDULE = "schedule" +PAYLOAD_KEY_FLOOR = "floor_bytes" +PAYLOAD_KEY_PIN_FQNS = "pin_fqns" + +# Schema version for the emitted offload payload. Bumped whenever the +# shape of any field above changes (e.g. 
switching the floor from +# uint64 bytes to a struct with prefetch headroom, or switching the +# budget wire option from ``weight_offload_budget_mb`` to a bytes-typed +# field once ``BackendOptions`` grows int64 support). The runtime +# hard-fails at ``CudaBackend::init`` if the version is missing or +# unknown, naming the expected range — version drift surfaces loudly at +# load instead of silently mis-parsing a payload. +SCHEMA_VERSION = 1 + + +def _apply_weight_offload( + exported_program, + *, + method_name: str, + pin_fqns: list[str] | None = None, +) -> dict: + """In-place graph rewrite + offload payload computation. + + INTERNAL — leading underscore is the Python signal. The only supported + entry point is the ``CudaPartitioner`` opt-in (see + ``backends/cuda/cuda_partitioner.py``), whose compile specs supply + ``method_name`` so prefill and decode get distinct payloads instead of + colliding. The rewrite itself must run from ``AotiBackend.preprocess``, + not from ``partition()`` (see the "payload transport" open item in the + module docstring). + ``method_name`` is REQUIRED (no default) precisely so a direct + caller importing this function for a multi-method model cannot + silently collide all methods on ``"forward"``. + + Inserts ``probe(w)`` in front of every consumer of every parameter (or + buffer) placeholder, rewriting the consumer's arg to read the probe's + output. One probe call per ``(consumer, weight)`` pair — not per + weight — so the runtime can re-load a weight that was evicted between + two uses inside the same forward pass. + + Pinned FQNs still get probes inserted (so the runtime serves them + through the same path), but they are EXCLUDED from the schedule and + EXCLUDED from the floor calculation. See ``Session::Config::pin_fqns`` + and ``Session::register_schedule`` in + ``backends/cuda/runtime/weight_offload/weight_offload.h``. + + The pass is the single authoritative source for the pin set. The + runtime has NO pin-set option; it parses the list out of the + partition payload and passes it to the Session unchanged.
Pin set + affects floor correctness (the floor is computed assuming pinned + FQNs do not stream), so the runtime cannot override it. + + AOTI constant-folding contract: + The pass operates on parameter placeholders in the ExportedProgram. + AOTI knobs that fold parameters out of the container at compile + time (so they no longer appear in ``get_constant_name(idx)``) + break offload: the pass cannot insert probes for parameters it + cannot see, and the folded constants would be loaded eagerly + through the normal blob path at runtime — silently defeating + offload and reintroducing the OOM this feature exists to prevent. + + Exports that enable weight offload must NOT have AOTI fold + parameter constants. The exact ``torch._inductor.config`` knob + and its required value is verified in the implementation PR. + The pass hard-fails at export if it detects folded parameters + (placeholder count below the expected catalog count derived + from ``exported_program.state_dict``); the runtime hard-fails + again at ``Session::bind_placeholder_constants`` if any catalog + FQN is missing a probe binding — defense in depth against the + two halves drifting. + + Metadata transport: + The returned ``dict`` is the offload payload that the + implementation PR will route to ``cuda_backend.cpp::init`` via the + AOTI ``preprocess`` path (see the "payload transport" open item + in the module docstring for the design constraint and the two + candidate channels: ``processed_bytes`` vs. ``NamedDataStore``). + + Args: + exported_program: an ``ExportedProgram`` produced by + ``torch.export.export``. Mutated in place: probe nodes are + inserted, consumer args are rewritten. + method_name: the method this pass is being applied to. Returned + verbatim in the payload so the runtime can validate which + method the bytes belong to. + pin_fqns: FQNs to mark as always-resident. Optional. 
The list + is propagated verbatim into the payload AND used by this pass + to exclude those FQNs from the schedule and the floor + calculation. Pinning an FQN that does not appear as a parameter + placeholder is a hard error. + + Returns: a ``dict`` with the keys defined at module scope (all + internal payload, not opt-in signals): + + - ``"version"``: ``int`` schema version (currently ``1``). Runtime + hard-fails on unknown version. + - ``"method_name"``: ``str``. Echoed for runtime validation. + - ``"schedule"``: ``list[str]`` of NON-PINNED parameter FQNs in + execution order. Drives the runtime cursor + prefetch (or is + obviated entirely if the implementation PR picks option (b) of + the "Schedule / cursor order" open item -- a probe-id arg). + Pinned FQNs do not appear here; their probes take a separate + fast-path in ``Session::serve`` that does not touch the cursor. + - ``"floor_bytes"``: ``int`` — minimum GPU byte budget for the + streaming portion of the working set (``max-over-consecutive- + kernel-pairs of (sum bytes K_i + K_{i+1}) + max single weight + size``), computed over the schedule above (i.e. excluding + pinned weights). The runtime asserts + ``(weight_offload_budget_mb << 20) - pinned_bytes`` covers + this; below-floor budgets hard-fail at init with the required + minimum spelled out. + - ``"pin_fqns"``: ``list[str]`` of FQNs the runtime keeps + resident. Empty if ``pin_fqns`` is unset. + + The opt-in signal is intended to be a separate + ``CompileSpec("weight_offload", b"1")`` emitted by ``CudaPartitioner`` + when wired -- the enable signal lives in exactly one place, the + compile spec, rather than being duplicated across compile spec + + payload. Neither the compile spec nor the partitioner kwarg exists + in this PR; see the EXPERIMENTAL banner at the top of this module. + + Not called by users. Not called by anything in this PR. + """ + ... 
diff --git a/backends/cuda/runtime/cuda_delegate_handle.h b/backends/cuda/runtime/cuda_delegate_handle.h index ee360531c47..9448960bbf7 100644 --- a/backends/cuda/runtime/cuda_delegate_handle.h +++ b/backends/cuda/runtime/cuda_delegate_handle.h @@ -168,6 +168,16 @@ struct CudaDelegateHandle : public aoti::AOTIDelegateHandle { // CUDA graph state (warmup, capture, replay, static buffers) CudaGraphState cuda_graph_state; + + // Weight offloading: the per-handle ``unique_ptr`` + // field lands with the implementation PR alongside ``weight_offload.cpp``, + // which provides the out-of-line ``Session::~Session()`` definition that + // ``unique_ptr``'s implicit destructor needs. Adding the field + // here in an API-surface-only PR would force every TU that includes this + // header into an unresolved-symbol link error against + // ``Session::~Session()``. See + // ``backends/cuda/runtime/weight_offload/weight_offload.h`` for the + // ownership model that explains why the Session lives per-handle. }; } // namespace cuda diff --git a/backends/cuda/runtime/weight_offload/prefetcher.h b/backends/cuda/runtime/weight_offload/prefetcher.h new file mode 100644 index 00000000000..6967fb21a02 --- /dev/null +++ b/backends/cuda/runtime/weight_offload/prefetcher.h @@ -0,0 +1,183 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// =========================================================================== +// EXPERIMENTAL -- NOT YET WIRED +// =========================================================================== +// Design document for the prefetcher Session::serve() will own once weight +// offloading is wired. No caller in this PR; no implementation file. See +// ``weight_offload.h`` (same directory) for the broader design and the +// list of open items that block wiring. 
+// =========================================================================== + +#include +#include +#include + +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { +namespace weight_offload { + +// --------------------------------------------------------------------------- +// Async prefetch protocol +// --------------------------------------------------------------------------- +// Implementation detail of ``weight_offload::Session::serve()``. Owned by +// the Session (one Prefetcher per DelegateHandle); not part of the +// user-facing API. This header is INTERNAL to the CUDA backend. +// +// Two streams in play: +// +// - **Compute stream**: borrowed from the owning Session (which got it +// from ``CudaDelegateHandle``). The prefetcher injects +// ``cudaStreamWaitEvent`` here but does not enqueue compute work. +// - **Copy stream**: owned by the prefetcher. Carries +// ``cudaMallocFromPoolAsync``, ``cudaMemcpyAsync`` (H2D), and any +// ``cudaFreeAsync`` for evictions handed off from the LRU. +// +// Each live allocation carries a ``cudaEvent_t`` recorded on the copy +// stream after its H2D completes. Consumer kernels on the compute stream +// must ``cudaStreamWaitEvent`` against that event before reading. +// +// Depth is one: each ``serve()`` call kicks off at most one +// ``opportunistic_prefetch`` for the next schedule entry. Deeper queues +// don't help in either the compute-bound regime (one-ahead already +// saturates the overlap budget) or the PCIe-bound regime (the copy stream +// is already serializing back-to-back transfers). + +// --------------------------------------------------------------------------- +// Eviction stream selection +// --------------------------------------------------------------------------- +// LRU evictions issue ``cudaFreeAsync`` on the stream that last touched +// the victim allocation. Two cases, one rule per case: +// +// 1. 
Consumed victim (compute stream last touched it). +// The owning kernel ran on the compute stream after its +// ``cudaStreamWaitEvent`` against the copy event. Issuing the free +// on the **compute stream** stream-orders it behind the consuming +// kernel, so the pool's cross-stream opportunistic reuse cannot +// hand the bytes to a new ``cudaMemcpyAsync`` until the compute +// stream catches up. Free-on-copy-stream would let the pool +// recycle bytes a still-running kernel is reading. +// +// 2. In-flight prefetched victim (no compute-stream work depends on +// it yet — its ``ready_event`` was recorded but no +// ``wait_on_compute_stream`` has been issued). +// Free-on-compute-stream is UNSAFE here: the compute stream has +// no pending work on the bytes, so ``cudaFreeAsync(compute_stream)`` +// queues behind nothing and the pool can recycle the bytes +// immediately, while the copy stream is still writing them. The +// eviction path must EITHER ``cudaStreamWaitEvent`` the copy +// event on the compute stream before issuing the +// ``cudaFreeAsync(compute_stream)``, OR issue the +// ``cudaFreeAsync`` on the **copy stream** so the free orders +// behind the in-flight H2D. Either is correct; copy-stream-free +// is simpler because it requires no extra event wait. +// +// The Prefetcher tracks per-LiveAllocation state (``consumed`` flag, +// flipped to true the first time ``wait_on_compute_stream`` runs for +// the FQN) so the eviction path can pick the right stream without +// extra bookkeeping in the caller. + +// --------------------------------------------------------------------------- +// Prefetcher +// --------------------------------------------------------------------------- +// Holds the copy stream + a map of per-FQN ``LiveAllocation`` records. +// Methods are single-threaded (see weight_offload.h thread-safety contract). 
+class Prefetcher { + public: + // Construct with the compute stream the owning Session is bound to and + // the Session-owned memory pool. Allocates the internal copy stream. + Prefetcher(cudaStream_t compute_stream, cudaMemPool_t pool); + ~Prefetcher(); + + Prefetcher(const Prefetcher&) = delete; + Prefetcher& operator=(const Prefetcher&) = delete; + + // ------------------------------------------------------------------------- + // Synchronous-on-the-CPU path: bring this FQN's bytes onto the GPU and + // record a ready event. Idempotent: if the allocation is already live + // (in flight or completed), returns the existing device pointer without + // re-issuing the copy. + // + // Side effects: + // - May trigger LRU eviction(s) if the budget is tight (frees issued + // on the compute or copy stream, selected per the "Eviction stream + // selection" rule above). + // - Issues ``cudaMallocFromPoolAsync`` on the copy stream. + // - Issues ``cudaMemcpyAsync`` (H2D) on the copy stream. + // - Records a ``cudaEvent_t`` on the copy stream after the copy. + // + // Hard-fails if the pool is exhausted even after evicting every + // evictable allocation (indicates the budget is below the floor — a + // bug, since ``Session::register_schedule`` should have caught it). + // ``Session::Config::pin_fqns`` entries are filtered out of the + // eviction candidate set by the Session before delegating here; the + // Prefetcher itself does not know about pinning. + ::executorch::runtime::Result ensure_ready(std::string_view fqn); + + // Issue ``cudaStreamWaitEvent`` on the compute stream against this + // FQN's ready event, so the next kernel sees the bytes. No-op if the + // event hasn't been recorded yet (the caller is responsible for + // calling ``ensure_ready`` first). Flips the LiveAllocation's + // ``consumed`` flag to true; subsequent ``release`` calls will + // issue ``cudaFreeAsync`` on the compute stream rather than the + // copy stream (see "Eviction stream selection" above).
+ void wait_on_compute_stream(std::string_view fqn); + + // ------------------------------------------------------------------------- + // Depth-1 best-effort prefetch. Same effects as ``ensure_ready`` except + // failures are logged and swallowed (the next ``serve()`` call will pay + // the cost synchronously). Returns ``Error::Ok`` even on allocation + // failure; only catastrophic errors surface. + ::executorch::runtime::Error opportunistic_prefetch(std::string_view fqn); + + // ------------------------------------------------------------------------- + // Eviction: free the allocation for this FQN. The stream the + // ``cudaFreeAsync`` is issued on is selected per the rule documented + // in "Eviction stream selection" above (compute-stream when the + // victim has been consumed; copy-stream when it's still in flight). + // Returns the freed byte count for budget bookkeeping. + ::executorch::runtime::Result release(std::string_view fqn); + + // ------------------------------------------------------------------------- + // Query: is there a live allocation for this FQN right now (in flight + // or completed)? Used by the LRU walker to pick eviction candidates. + bool is_live(std::string_view fqn) const; + + // Total currently-allocated bytes across all live FQNs. Used to gate + // new allocations against the budget. + uint64_t bytes_in_use() const; + + private: + // Per-live-allocation state. + // device_ptr : pool slot returned by cudaMallocFromPoolAsync + // nbytes : size of the H2D copy + // ready_event : recorded on copy stream after H2D completes + // consumed : false initially; flipped to true on the first + // ``wait_on_compute_stream`` for this FQN. + // Determines which stream ``release`` issues + // ``cudaFreeAsync`` on (see "Eviction stream + // selection" above). 
+ // lru_tick : monotonic counter set on last access (for LRU) + struct LiveAllocation; + + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace weight_offload +} // namespace cuda +} // namespace backends +} // namespace executorch diff --git a/backends/cuda/runtime/weight_offload/probe_op.h b/backends/cuda/runtime/weight_offload/probe_op.h new file mode 100644 index 00000000000..df733494b94 --- /dev/null +++ b/backends/cuda/runtime/weight_offload/probe_op.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// =========================================================================== +// EXPERIMENTAL -- NOT YET WIRED +// =========================================================================== +// Design document for the AOTI c-shim that ``executorch_weight_offload::probe`` +// will lower to. No implementation, no caller, and the export-side probe +// op is itself not yet preserved through inductor lowering (see the +// "Probe op preservation" open item in ``weight_offload.h``). If this +// symbol were added to the runtime today, no AOTI wrapper would call +// it. +// =========================================================================== + +#include +#include + +extern "C" { + +// AOTI c-shim for ``executorch_weight_offload::probe``. +// +// AOTI's name mangling convention is ``aoti_torch_<device>_<op>`` +// (see PyTorch's ``cpp_wrapper_cpu.py:1638-1649``). For +// ``executorch_weight_offload::probe(w: Tensor) -> Tensor`` on CUDA, the +// generated ``wrapper.cpp`` emits a direct call to +// ``aoti_torch_cuda_probe(input_handle, &output_handle)``. The C signature +// of this symbol is the string the CUDA backend hands AOTI via the +// ``aot_inductor.custom_ops_to_c_shims`` config (see +// ``backends/cuda/cuda_backend.py``).
+// +// At runtime, the symbol must be globally visible from the host process so +// the dynamically-loaded AOTI ``.so`` can resolve it. The CUDA backend's +// ``platform.cpp`` dlopen handshake promotes ``libaoti_cuda_shims.so``'s +// symbols to global on first ``load_library`` call. +// +// Implementation (in ``probe_op.cpp``) is a thin forwarder: +// 1. Cast the input handle to ``SlimTensor*``. +// 2. ``weight_offload::ProbeRegistry::instance().lookup(input->data_ptr())`` +// to find the owning ``Session*``. Unknown pointer is a hard +// failure — it means the pass emitted a probe for a weight no +// Session bound, or two Sessions collided. +// 3. Call ``session->serve(input)``. +// 4. Write the returned ``SlimTensor*`` to ``*output`` as an +// ``AtenTensorHandle``. +// +// The ProbeRegistry indirection is what keeps the runtime per-Session: +// the C ABI only sees the input tensor, so the registry is the only +// process-global state weight offloading requires. +::executorch::runtime::Error aoti_torch_cuda_probe( + ::executorch::backends::aoti::AtenTensorHandle input, + ::executorch::backends::aoti::AtenTensorHandle* output); + +} // extern "C" diff --git a/backends/cuda/runtime/weight_offload/weight_offload.h b/backends/cuda/runtime/weight_offload/weight_offload.h new file mode 100644 index 00000000000..e8f9d6c896d --- /dev/null +++ b/backends/cuda/runtime/weight_offload/weight_offload.h @@ -0,0 +1,811 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +// =========================================================================== +// EXPERIMENTAL -- NOT YET WIRED +// =========================================================================== +// This header is a DESIGN DOCUMENT. 
The Session/WeightCatalog/ProbeRegistry +// types declared below have no implementation in this PR, no caller in the +// CUDA backend, and no field on ``CudaDelegateHandle``. The two public +// knobs described in the Scope block (``CudaPartitioner(weight_offload=...)`` +// and the ``weight_offload_budget_mb`` runtime spec) are also NOT YET +// WIRED -- the partitioner kwargs are intentionally not exposed in this +// PR (see ``cuda_partitioner.py``), and ``cuda_backend.cpp`` does not +// read the runtime spec. +// +// Three open questions block wiring and are tracked here so reviewers don't +// have to re-discover them when the implementation PR lands: +// +// * Probe op preservation (export half). The pass inserts an identity +// ``executorch_weight_offload::probe(w)`` custom op before every +// parameter consumer. ``custom_op`` with ``mutates_args=()`` is NOT +// sufficient to guarantee survival through inductor's CSE + fusion +// passes -- an identity op with no side effect is a textbook DCE +// target. The implementation PR must both (a) give probe explicit +// barrier semantics (e.g. ``mutates_args={"w"}`` plus a side-effecting +// marker that doesn't trip torch.export's parameter-output validation, +// or a ``torch.ops.aten`` placeholder the lowering rewrites), AND +// (b) ship a regression test that asserts the lowered AOTI wrapper +// contains one ``aoti_torch_cuda_probe`` call per +// ``(consumer, weight)`` pair. Without both, the runtime contract +// ("every parameter consumer's read goes through ``Session::serve``") +// is unenforced and offload silently degrades to eager loading. +// +// * AOTI blob layout (runtime half). ``WeightCatalog::build`` needs +// each constant's offset within the AOTI weights blob to copy bytes +// into pinned host memory, plus its ``data_size``, ``dtype``, +// ``ndim``, and ``shape`` to size the copy and construct SlimTensor +// views over the live device allocations later.
AOTI exposes NONE +// of these getters today (no ``get_constant_blob_offset``, no +// ``get_constant_data_size``, no ``get_constant_dtype/ndim/shape``). +// Two options for the implementation PR: (a) land the upstream +// AOTI shims, add typedefs to ``aoti_delegate_handle.h``, and +// dlsym them from ``cuda_backend.cpp``; (b) compute the metadata +// at export time and serialize it into the offload payload +// alongside the schedule / floor / pin set. A first-constant +// fingerprint check at init is NOT a substitute -- it can only +// catch alignment drift on ``idx=0``; it cannot detect e.g. a +// reordered or padded interior constant. The fingerprint idea was +// in an earlier draft of this header and has been removed; pick +// one of (a)/(b) when wiring. +// +// * Schedule / cursor order. ``Session::serve`` hard-fails on a +// cursor mismatch (see ``register_schedule`` and the hot-path hard- +// fail catalog), and that contract depends on the execution-order +// FQN list the pass records matching the order the lowered AOTI +// wrapper actually invokes ``aoti_torch_cuda_probe`` in. The pass +// observes the ExportedProgram BEFORE ``AotiBackend.preprocess`` +// runs custom passes / decomposition / inductor lowering (see +// ``backends/aoti/aoti_backend.py:206``), all of which can reorder +// or duplicate parameter reads. The two halves of the contract can +// drift silently -- which the runtime would then catch as a hard +// fail on the first wrong-order probe, but only after a debugging +// session traceable to whatever pass mutated the order. Two options +// for the implementation PR: (a) generate the schedule from the +// final wrapper / post-lowering order (extract probe call order +// from inductor's output, not from the pre-lowering graph); or +// (b) make probe self-identifying -- add an explicit +// ``int64_t probe_id`` (or FQN ``string``) argument so each +// ``serve()`` call knows which weight it's for without a cursor. 
+// (b) removes the cursor mismatch class entirely at the cost of a +// wider op signature. +// +// Everything below is preserved as design intent for the implementation +// PR. Do NOT include this header from active code paths. +// =========================================================================== + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace cuda { +namespace weight_offload { + +// --------------------------------------------------------------------------- +// Scope +// --------------------------------------------------------------------------- +// This header is INTERNAL to the CUDA backend. It is not a user-facing API. +// +// The public contract for weight offloading is exactly two knobs: +// 1. Export-time: ``CudaPartitioner(compile_spec=[...], +// weight_offload=True, weight_offload_pin_fqns=[...])`` — the +// partitioner is the single authoritative source for the pinning +// policy (the pass needs it to compute the floor correctly; see +// the schedule/floor contract on ``Session::register_schedule`` +// below). The ``compile_spec`` positional carries the method name +// via ``CudaBackend.generate_method_name_compile_spec(...)``; the +// pass uses it to namespace the payload so prefill and decode +// don't collide. (Transport channel: see the OPEN item below.) +// 2. Load-time: backend runtime spec ``weight_offload_budget_mb`` +// (int, megabytes) read by ``cuda_backend.cpp`` from +// ``BackendInitContext::get_runtime_spec("weight_offload_budget_mb")``. +// The runtime takes NO pin-set option; the pin set is fully +// determined at export time and propagated inside the payload +// described below. 
+// +// Unit is megabytes (not bytes) because ``BackendOptions`` today +// stores ``int`` (~31-bit; see ``runtime/backend/options.h``) and +// ``BackendInitContext::get_runtime_spec`` only supports ``bool``, +// ``int``, and ``const char*``, which overflows on any realistic +// LLM budget when expressed in bytes. Sub-MB budgets are +// meaningless for transformer-scale weights anyway. Switch to +// ``weight_offload_budget_bytes`` (int64) the moment +// ``BackendOptions`` grows 64-bit support; the payload schema +// version below gates the upgrade. +// +// Do NOT promote this knob to ``Module(pte, +// weight_offload_budget_mb=512)``. +// ``Module`` is backend-agnostic and per-program; the budget is per +// backend AND per method (prefill and decode get separate Sessions +// with separate budgets). It belongs on the runtime-spec channel +// where it can be set per-(backend, method) without leaking +// CUDA-specific knobs into the cross-backend Module API. +// +// The opt-in signal carried from the partitioner to the runtime is the +// single ``CompileSpec("weight_offload", b"1")`` the partitioner emits; +// ``cuda_backend.cpp::init`` checks that compile spec to decide whether +// to parse the offload payload and construct a ``Session`` for the +// handle. The payload is internal data — not an opt-in signal. The +// runtime reads it only after the compile spec confirms offload is +// enabled. +// +// Metadata transport -- OPEN: +// The pass needs to run from inside ``AotiBackend.preprocess`` (the +// ``ExportedProgram`` partitioner contract at +// ``exir/backend/partitioner.py:83`` forbids the partitioner itself +// from mutating the program). Two candidate channels for handing the +// schedule / floor / pin_fqns payload across to the runtime, both +// accessible from ``preprocess`` and neither yet wired: +// (a) Serialize into the partition's ``processed_bytes`` +// (currently ``b""`` for the AOTI backend; see +// ``backends/aoti/aoti_backend.py``). 
Tiny, partition-local, +// no cross-method dedup needed. +// (b) Attach a per-method ``NamedDataStore`` entry from inside +// ``preprocess`` (where ``AotiBackend`` already writes +// ``_so_blob`` and ``_weights_blob``). Same lifetime, slightly +// more general. +// Pick one in the implementation PR; the schema below is channel- +// agnostic. +// +// Payload schema (one entry per method, keyed by ``method_name`` so +// prefill and decode coexist in the same .pte): +// version uint64 schema version. Bumped whenever any field +// below changes shape. ``init`` hard-fails on missing +// or unknown version with the expected range spelled +// out — schema drift surfaces loudly instead of +// silently mis-parsing. +// schedule list[str] of NON-PINNED FQNs in execution order; +// drives the cursor + prefetch. +// floor_bytes uint64 minimum streaming-pool byte budget. +// ``init`` asserts ``(budget_mb << 20) - pinned_bytes +// >= floor_bytes``; below-floor hard-fails at init. +// pin_fqns list[str] of FQNs the runtime keeps resident. +// Empty if no pinning was requested. +// +// AOTI constant-folding contract: +// The pass operates on ``ExportedProgram`` parameter placeholders. +// AOTI has knobs that can drop parameters out of the container by +// folding them into kernels (the ``ExtractConstantsMap`` / +// ``get_constant_name`` sequence no longer enumerates them). The +// pass cannot insert probes for parameters it cannot see, so those +// constants would be loaded eagerly through the normal blob path at +// runtime — silently defeating offload and reintroducing the OOM +// this feature exists to prevent. +// +// The contract: exports that enable weight offload must NOT have +// AOTI fold parameter constants. The exact ``torch._inductor.config`` +// knob and its required value is verified in the implementation PR +// and asserted by the pass at export time (placeholder count vs. +// container catalog count). 
Defense in depth: the runtime +// hard-fails again at ``Session::bind_placeholder_constants`` if +// any catalog FQN is missing a probe binding, so the two halves +// cannot drift silently. +// +// Everything below — Session, WeightCatalog, ProbeRegistry, the +// bind/register/serve sequence — is implementation detail. Callers outside +// ``backends/cuda/runtime/`` should not include this header. + +// --------------------------------------------------------------------------- +// Hard-fail catalog +// --------------------------------------------------------------------------- +// Every condition below is a HARD FAIL at the named lifecycle step, +// with a diagnostic identifying which entry was violated. None of them +// silently downgrade to the non-offload path: a silent fallback would +// let the same exported artifact stream weights on one build and +// eagerly load them on another, which is exactly the surprise this +// opt-in is designed to prevent. +// +// At export (_apply_weight_offload): +// * ``pin_fqns`` references an FQN that is not a parameter +// placeholder. +// * AOTI folded one or more parameter constants out of the +// container (placeholder count vs. catalog count mismatch); +// see the constant-folding contract above for the why. +// +// At init (CudaBackend::init when ``weight_offload`` CompileSpec is +// present): +// * Any AOTI symbol in ``AOTIFunctions`` is unresolved on the +// method's ``.so``. +// * ``weight_offload_budget_mb`` runtime spec is unset (no +// auto-sizing default — loud and early is debuggable). +// * The offload payload is absent on whichever transport channel +// the implementation PR picks, or names an unknown schema version. +// * Any of the schedule / floor / pin_fqns fields in the payload +// is missing or malformed. +// * A pin_fqn is not present in the catalog. +// * ``(budget_mb << 20) - pinned_bytes < floor_bytes``. 
+// * Both ``weight_offload`` and ``enable_cuda_graph_for_method`` +// are set for the same method (CUDA-graph Replay bypasses +// AOTI's run() and probes never fire). +// +// At hot path (Session::serve, via the probe c-shim): +// * ``ProbeRegistry::lookup`` returns ``Error::NotFound`` for a +// probe input (a binding was missed at init, or two Sessions +// collided in the registry). +// * The streaming-path FQN does not match the schedule's current +// cursor (pass-vs-AOTI emission-order drift -- see the +// "Schedule / cursor order" OPEN item at the top of this file +// for why this hard fail is possible and the two ways to remove +// the class -- or stale pin set whose excluded FQNs no longer +// match the schedule). + +// --------------------------------------------------------------------------- +// Ownership model +// --------------------------------------------------------------------------- +// ProbeRegistry process-global forced by the C ABI of +// ``aoti_torch_cuda_probe`` (input +// tensor only — must dispatch from +// dummy ``data_ptr()`` to the owning +// Session). Tiny: dispatch table only. +// +// WeightCatalog per-Program immutable after construction: host +// pinned mirror of weight bytes + +// per-FQN metadata. Sessions hold a +// ``shared_ptr``, so multiple methods +// loaded from one .pte (e.g. prefill + +// decode) dedup the host bytes. +// +// Session per-DelegateHandle owns mutable execution state: +// GPU pool, copy stream, schedule +// cursor, live allocations, stats. +// Per-method budget — each Session has +// its own. Lifetime tied to +// ``CudaDelegateHandle::weight_offload_session``. +// +// Why per-handle Session, not per-Program: prefill and decode have very +// different working sets, and forcing them to share a single GPU pool +// means whichever runs first locks in eviction patterns the other has to +// fight. Per-method budget is the right contract; deduping the host +// mirror is the only sharing that's clearly correct. 
+ +// --------------------------------------------------------------------------- +// Thread-safety contract +// --------------------------------------------------------------------------- +// A single Session is single-threaded on the hot path. The init-time APIs +// (``Session::bind_placeholder_constants``, ``Session::register_schedule``) +// may be called concurrently for different Sessions; CudaBackend already +// holds ``shared_constants_mutex_`` when initializing handles, so +// per-Session construction does not race. +// +// The hot-path APIs (``Session::on_execute_begin``, ``Session::serve``, +// ``Session::stats``) assume the calling thread is the only one currently +// executing inside that Session. Concurrent ``execute()`` calls on +// different methods that each have their own Session are safe; concurrent +// calls into the same Session are undefined behavior. This matches the +// non-atomic ``SharedPtr`` assumption baked into the SlimTensor layer +// (see ``backends/aoti/slim/util/shared_ptr.h``). +// +// ``ProbeRegistry`` is internally synchronized for lookup/registration/ +// unregistration (rare-write, frequent-read). + +// --------------------------------------------------------------------------- +// AOTI function pointers required by WeightCatalog + Session +// --------------------------------------------------------------------------- +// dlsym'd per-method from each .so by ``cuda_backend.cpp``; passed in so +// the weight_offload layer never touches dlfcn directly. +// +// ``WeightCatalog::build`` reads per-constant name + FQN via these +// pointers. Per-constant ``data_size``, ``dtype``, ``shape``, and blob +// offset are open items -- see the WeightCatalog::build docstring and +// the AOTI-blob-layout OPEN item in the EXPERIMENTAL banner at the top +// of this file. 
The actual constant bytes come from a separate +// ``weight_blob`` parameter — AOTI does NOT expose the constant bytes +// through any container API when constructed with +// ``include_weights=false`` (which is the standard case for the CUDA +// backend; see ``model_base.h:609-611``). The CUDA backend sources the +// blob the same way it does for the non-offload path: from the .pte +// data, currently via ``update_constants_from_blob``. +// +// Note on constant folding: see the constant-folding contract in the +// Scope block. ``WeightCatalog::build`` here can assume one container +// entry per parameter placeholder the pass saw; any deviation is a +// hard fail at init. +// +// ``Session::bind_placeholder_constants`` uses +// ``update_user_managed_constant_buffer_pairs`` to install 1-byte GPU +// dummies (one per non-pinned constant; a real pool allocation for each +// pinned constant) in place of the container's default constant slots. +// Note: ``WeightCatalog::build`` additionally needs per-constant +// ``data_size``, ``dtype``, ``ndim``, ``shape``, and blob offset. +// AOTI exposes NONE of those shims today. Listing them here as +// ``nullptr``-defaulted fields would be misleading -- it suggests the +// loader in ``cuda_backend.cpp`` knows how to populate them, when in +// fact none of those typedefs exist in ``aoti_delegate_handle.h`` +// either. The implementation PR must either land the upstream shims +// (and add fields here once the loader actually dlsyms them) or +// serialize the metadata into the offload payload at export time. +// See the AOTI-blob-layout OPEN item in the EXPERIMENTAL banner at +// the top of this file for the two options. 
+struct AOTIFunctions { + ::executorch::backends::aoti::AOTInductorModelContainerGetNumConstantsFunc + get_num_constants{nullptr}; + ::executorch::backends::aoti::AOTInductorModelContainerGetConstantNameFunc + get_constant_name{nullptr}; + ::executorch::backends::aoti:: + AOTInductorModelContainerGetConstantOriginalFQNFunc + get_constant_original_fqn{nullptr}; + ::executorch::backends::aoti:: + AOTInductorModelContainerUpdateUserManagedConstantBufferPairsFunc + update_user_managed_constant_buffer_pairs{nullptr}; +}; + +// --------------------------------------------------------------------------- +// Runtime statistics +// --------------------------------------------------------------------------- +// Per-Session cumulative counters since Session construction (or last +// ``reset_stats()``). All counters are incremented by ``Session::serve`` +// (the streaming path) and ``Prefetcher`` internals; pinned-FQN serves do +// NOT touch any of these counters — pinned reads bypass the pool entirely +// and aren't part of the streaming workload these stats describe. +// +// Implementer: every increment is a single uint64_t bump under the +// single-threaded-per-Session invariant (see thread-safety contract); no +// atomics needed. ``reset_stats()`` zeroes every field; in-flight +// prefetches whose ``prefetch_attempted++`` has already fired but whose +// ``prefetch_succeeded++`` has not are lost (acceptable: the field is a +// counter, not a gauge). +struct SessionStats { + // Number of ``serve()`` streaming-path calls that found the FQN's + // bytes already live in the pool — whether warmed by an earlier + // synchronous serve, a successful ``opportunistic_prefetch``, or + // survival across an ``on_execute_begin`` boundary. Counts the + // event of "no synchronous H2D was needed," NOT "no eviction was + // needed" — the two diverge when prefetch lands in time. 
+ uint64_t pool_hits; + + // Number of ``serve()`` streaming-path calls that had to issue a + // synchronous H2D because no live allocation existed (or the + // allocation existed but its ready event hadn't fired yet, forcing + // a wait that the hit path would have avoided). Sum ``pool_hits + + // pool_misses`` equals the total streaming-path ``serve()`` calls. + uint64_t pool_misses; + + // Number of LRU evictions performed across the lifetime of this + // Session — counted once per victim ``cudaFreeAsync``, regardless + // of whether the eviction was triggered by ``ensure_ready`` or + // ``opportunistic_prefetch``. A single ``serve()`` call that needs + // room for two victims contributes 2 here. + uint64_t evictions; + + // Total bytes issued to ``cudaMemcpyAsync`` H2D across the pool's + // copy stream — counted at issue time (event-record), not at + // event-complete. Includes both synchronous-on-the-hot-path copies + // and opportunistic prefetch copies. Does not include the initial + // catalog build's host-side pinning copy (that's per-Program, not + // per-Session). + uint64_t bytes_h2d_copied; + + // Number of ``opportunistic_prefetch`` calls attempted by + // ``serve()`` after advancing the cursor — incremented BEFORE the + // prefetch is issued, so ``prefetch_attempted - prefetch_succeeded`` + // is the count of swallowed-error prefetches (allocation failed, + // budget tight, etc.; see ``prefetcher.h:opportunistic_prefetch``). + uint64_t prefetch_attempted; + + // Number of ``opportunistic_prefetch`` calls that successfully + // issued the H2D — incremented after ``cudaMemcpyAsync`` is queued + // on the copy stream, not after the event completes. A prefetch + // that issues but is never consumed (e.g. evicted before the next + // serve) still counts as succeeded here; the consumption signal + // shows up in ``pool_hits``. 
+ uint64_t prefetch_succeeded; +}; + +// --------------------------------------------------------------------------- +// WeightCatalog — per-Program immutable host mirror +// --------------------------------------------------------------------------- +// Built once per ``Program`` from the .pte's weight blob plus the AOTI +// container's per-constant metadata. Owns (or references — see below) the +// pinned host bytes plus per-FQN dtype/sizes/nbytes. Read-only after +// construction; safe to share across Sessions via ``shared_ptr``. +// +// CudaBackend keeps a process-global ``unordered_map>`` so methods loaded from the same .pte reuse +// the same catalog (prefill + decode dedup the host bytes). The +// ``ProgramKey`` derivation is a CudaBackend implementation detail +// (likely the container's ``so_path`` or an AOTI-blob fingerprint). +class WeightCatalog { + public: + ~WeightCatalog(); + WeightCatalog(const WeightCatalog&) = delete; + WeightCatalog& operator=(const WeightCatalog&) = delete; + + // Build a catalog. The bytes come from ``weight_blob`` (the same blob + // the CUDA backend would otherwise pass to + // ``update_constants_from_blob``); the catalog needs per-constant + // dtype / ndim / shape / nbytes / blob offset to slice the blob and to + // construct SlimTensor views over the live device allocations later. + // + // OPEN: per-constant blob offset. + // ``WeightCatalog::build`` must know each constant's offset within + // ``weight_blob`` to copy bytes out into pinned host memory. AOTI + // does NOT expose ``get_constant_blob_offset`` today. The + // implementation PR must pick one of two paths: + // + // (a) Add a ``get_constant_blob_offset`` shim upstream in + // PyTorch's AOTI codegen, dlsym it from ``cuda_backend.cpp``, + // and read offsets at init. + // (b) Compute offsets at export time (the pass already walks the + // placeholders in order) and serialize them into the + // partition payload alongside the schedule / floor / pin set. 
+ // + // Earlier drafts of this header proposed a "compute offsets from + // declaration order + fixed alignment, verify with a first-constant + // fingerprint" scheme. That scheme has been removed: a + // first-constant fingerprint can only catch alignment drift on + // ``idx=0``, not interior reordering or per-constant padding + // changes, and AOTI does not contractually guarantee either order + // or alignment. Either pick up (a) or (b); do not reintroduce the + // fingerprint shortcut. + // + // OPEN: per-constant dtype / ndim / shape. + // Same gap, same two options -- AOTI doesn't expose these getters + // today. (a) add shims upstream, or (b) serialize at export time. + // See the AOTIFunctions block above for the corresponding nullptr + // fields that were removed. + // + // The catalog ALWAYS owns the pinned host bytes: ``build`` copies + // out of ``weight_blob`` into a freshly-allocated ``cudaHostAlloc``'d + // region keyed by FQN, and the destructor ``cudaFreeHost``s it. The + // caller is free to drop ``weight_blob`` the moment ``build`` + // returns. Always-own keeps lifetime inside the type contract — a + // zero-copy ``cudaHostRegister``-the-input alternative was rejected + // because it leaks ownership back to the caller through nothing + // stronger than a docstring, which is exactly the kind of contract + // we don't want at an internal boundary the rest of the runtime + // relies on. + // + // Preconditions: + // - The container has ``include_weights=false`` (standard for the + // CUDA backend). + // - All ``AOTIFunctions`` fields are non-null. + // - ``weight_blob`` and ``weight_blob_size`` describe the AOTI + // constants blob in its on-disk layout (same format + // ``update_constants_from_blob`` parses). + // - The implementation has a way to obtain per-constant offset and + // dtype/shape (see the two OPEN items above). 
+ static ::executorch::runtime::Result> build( + AOTIFunctions fns, + ::executorch::backends::aoti::AOTInductorModelContainerHandle container, + const uint8_t* weight_blob, + uint64_t weight_blob_size); + + // Per-FQN view: host pointer, size, dtype/shape. Strides are not + // stored — AOTI constants are always contiguous, so callers compute + // strides from sizes when constructing the SlimTensor in + // ``Session::serve``. + struct WeightView { + const void* host_ptr; + uint64_t nbytes; + int32_t dtype; + std::vector sizes; + }; + + // Returns nullptr if the FQN is not in the catalog. + const WeightView* find(std::string_view fqn) const; + + // Iteration for ``Session::bind_placeholder_constants`` (it needs the + // list of (fqn, nbytes) pairs to size each dummy correctly) and for + // ``Session::register_schedule``'s pin-set validation. + size_t size() const; + const WeightView& at(size_t idx) const; + std::string_view fqn_at(size_t idx) const; + + private: + struct Impl; + std::unique_ptr impl_; + explicit WeightCatalog(std::unique_ptr); +}; + +// --------------------------------------------------------------------------- +// Session — per-DelegateHandle execution context +// --------------------------------------------------------------------------- +// Owned by ``CudaDelegateHandle::weight_offload_session``. Destructor +// frees the pool, removes this Session's entries from the ProbeRegistry, +// and drops the shared_ptr to the catalog (the catalog itself is freed +// when the last Session drops its reference). +// +// Note: the destructor does NOT call ``AOTInductorModelContainerDelete`` +// on the container — that API is intentionally never called by +// CudaBackend (see ``cuda_backend.cpp:980-987``). The container's +// references to the dummy pointers go dangling, but nothing reads them +// after Session destruction. 
+// +// Lifecycle (call order is part of the contract): +// +// create() [init, once per Session] +// │ +// ▼ +// bind_placeholder_constants() [init, once per Session] +// │ ──► registers (dummy_data_ptr → Session, FQN) with the +// │ ProbeRegistry so serve() can later resolve probes +// ▼ +// register_schedule() [init, once per Session] +// │ ──► validates budget vs. floor; init fails here if the +// │ runtime budget cannot cover the schedule +// ▼ +// ┌─►on_execute_begin() [hot path, once per execute()] +// │ │ ──► resets the cursor to 0; live allocations PERSIST +// │ ▼ +// │ serve() serve() serve() ... [hot path, called from c-shim] +// │ │ +// │ ▼ +// └──(next execute()) +// │ +// ▼ +// ~Session() [teardown; unregisters from +// ProbeRegistry BEFORE pool free] +// +// Calling order rules — violations are bugs in the caller, NOT +// runtime user errors. The contract is enforced by CudaBackend; the +// Session itself does not guard against misordering (would add hot- +// path branches for a class of bug only the backend can produce): +// +// - ``bind_placeholder_constants`` MUST be called before any +// ``execute()`` reaches ``serve()``. If a probe fires before +// binding, ``ProbeRegistry::lookup`` returns ``Error::NotFound`` +// and the c-shim hard-fails — surfacing the misorder loudly, +// not silently. +// - ``register_schedule`` MUST be called before the first +// ``on_execute_begin``. Without it, the cursor has no schedule +// to validate against and the first streaming ``serve()`` would +// hard-fail "schedule empty." +// - ``on_execute_begin`` MUST be called before every ``execute()`` +// run that will route probes through this Session. CudaBackend +// calls it from its ``execute()`` entry before invoking AOTI's +// ``run()``. +// +// Live-allocation persistence across runs (warm cache): +// ``on_execute_begin`` resets ONLY the schedule cursor. 
All live +// pool allocations from previous runs survive untouched — that +// persistence is what makes the H2D/compute overlap story work +// under steady-state inference (each run starts hot, and only the +// first run pays the full streaming cost). Stats counters +// (``SessionStats``) also persist across runs; use +// ``reset_stats()`` if a fresh window is needed. +class Session { + public: + // Configuration for one method. + // budget_bytes : pool cap in bytes. ``CudaBackend::init`` derives + // this from the runtime spec + // ``weight_offload_budget_mb`` (int megabytes, + // fetched via + // ``BackendInitContext::get_runtime_spec``) + // as ``static_cast(mb) << 20``; + // megabytes is the wire unit because + // ``BackendOptions`` currently stores ``int`` + // (~31-bit), which overflows on realistic LLM + // budgets when expressed in bytes. The Session + // itself works in bytes — no rounding happens + // past the init boundary. Below the (pin- + // adjusted) schedule floor → init fails with + // the required minimum spelled out. + // compute_stream : the CudaDelegateHandle's stream. Session does + // not own it. + // pin_fqns : FQNs to keep resident — never evicted, never + // streamed. Pinned bytes are subtracted from + // ``budget_bytes`` before the floor check. + // + // The pin set is SET AT EXPORT TIME by the + // partitioner and reaches the runtime inside the + // payload described in the metadata-transport + // OPEN item at the top of this header. + // ``CudaBackend::init`` parses the payload and + // passes the list here verbatim — there is no + // user-facing runtime knob for pinning. The pin + // set affects floor correctness (the export-time + // floor is computed assuming these FQNs do not + // stream), so the runtime is not free to override + // it. 
+ // + // The pass also excludes pinned FQNs from the + // schedule (see ``register_schedule``), so + // ``serve`` for a pinned FQN is a fast-path + // resident lookup with no cursor advance — it + // could not advance the cursor anyway since + // pinned FQNs are not in the cursor's sequence. + // method_name : for stats + error messages only. + struct Config { + uint64_t budget_bytes{0}; + cudaStream_t compute_stream{nullptr}; + std::vector pin_fqns; + std::string method_name; + }; + + // Construct. Allocates the pool, the copy stream, and the live-allocation + // table. Does NOT touch the AOTI container — call + // ``bind_placeholder_constants`` next. + static ::executorch::runtime::Result> create( + Config cfg, + std::shared_ptr catalog); + + ~Session(); + Session(const Session&) = delete; + Session& operator=(const Session&) = delete; + + // Walks the container's constants, allocates one 1-byte GPU dummy per + // non-pinned constant (a real pool allocation for each pinned constant), + // calls ``update_user_managed_constant_buffer_pairs(use_inactive=false, + // validate_full_update=true)``, and registers each + // ``(dummy_data_ptr -> this Session, FQN)`` entry with the + // ProbeRegistry so ``serve()`` can resolve the FQN at hot-path time. + // + // Preconditions: + // - ``container`` was constructed with ``include_weights=false`` + // (standard AOTI-package case for the CUDA backend; see + // ``model_base.h:609-611``). + // - All ``AOTIFunctions`` fields are non-null. The + // ``weight_offload`` CompileSpec is the contract that those + // symbols are resolvable; if the AOTI ``.so`` this handle dlopened + // does not export one of them, ``CudaBackend::init`` MUST hard-fail + // with the missing symbol named, not silently fall back to the + // non-offload path. A silent downgrade would let the same exported + // artifact stream weights on one build and eagerly load them on + // another, which is the surprise this opt-in is meant to prevent. 
+ // - The catalog passed to ``create`` covers every FQN in the + // container; mismatch returns ``Error::InvalidArgument``. The + // symmetric case — a catalog FQN with no corresponding container + // constant — indicates AOTI folded a constant into a kernel + // (constant folding was not disabled at export); also hard-fails + // here, naming the folded FQN. + ::executorch::runtime::Error bind_placeholder_constants( + AOTIFunctions fns, + ::executorch::backends::aoti::AOTInductorModelContainerHandle container); + + // Records the streaming schedule + the minimum GPU byte budget + // computed at export time (both delivered inside the offload payload; + // see the metadata-transport OPEN item at the top of this header for + // the channel choices). + // + // ``schedule`` is the execution-order FQN list of NON-PINNED weights + // only. Pinned FQNs are excluded by the pass; the cursor only + // advances through this list, and an attempt to ``serve`` an FQN that + // is neither pinned nor at the current cursor position is a hard + // failure. See the "Schedule / cursor order" OPEN item at the top of + // this header for the export-vs-AOTI emission-order question this + // contract depends on. + // + // ``floor_bytes`` is the minimum GPU budget for the streaming portion + // of the working set (``max-over-consecutive-kernel-pairs of (sum + // bytes K_i + K_{i+1}) + max single weight size``), excluding pinned + // weights. + // + // Validation: ``cfg.budget_bytes - pinned_bytes >= floor_bytes``, + // where ``pinned_bytes`` is the sum of catalog sizes for + // ``cfg.pin_fqns``. A below-floor budget causes a hard fail with the + // required minimum spelled out: loud-and-early is debuggable; silent + // slowdown via event-aware eviction is not. Pinning an FQN that is + // not in the catalog is also a hard failure (typo / stale pin list). + // + // Idempotent for the same ``(schedule, floor_bytes)``. Conflicting + // re-registration returns ``Error::InvalidArgument``. 
+ ::executorch::runtime::Error register_schedule( + std::vector schedule, + uint64_t floor_bytes); + + // Called by ``CudaBackend::execute`` before AOTI's ``run()``. Resets + // ONLY the schedule cursor to 0 (so a previous mid-run error doesn't + // leave the cursor desynced). Live pool allocations from previous + // runs are NOT touched — they persist as a warm cache, which is what + // makes steady-state inference cheap. See the "Live-allocation + // persistence across runs" note in the Session class header above. + void on_execute_begin(); + + // Hot path -- serve one probe call. + // + // 1. Resolve ``input->data_ptr()`` to an FQN via this Session's + // ProbeRegistry entries. + // 2. If the FQN is in ``cfg.pin_fqns``: return a SlimTensor view of + // the resident pinned allocation directly. No copy, no event + // wait, no cursor touch — pinned FQNs are not in the schedule. + // 3. Otherwise (streaming path): assert the resolved FQN matches + // the schedule's current cursor. Mismatch is a hard failure + // (bug in the pass, in AOTI's emission order, or a stale pin + // list whose excluded FQNs no longer match the schedule). + // 4. Ensure the pool has a live allocation for this FQN. If not, + // ``cudaFreeAsync`` LRU victims on the compute stream (pinned + // allocations excluded from the candidate set) until there's + // room, ``cudaMallocFromPoolAsync`` on the copy stream, and + // ``cudaMemcpyAsync`` from the catalog's host mirror. Record + // the ready event on the copy stream. + // 5. Issue ``cudaStreamWaitEvent`` so the next kernel on the + // compute stream sees the bytes. + // 6. Advance the cursor; opportunistically prefetch the next + // schedule entry's weight if it doesn't already have a live + // allocation. + // 7. Return a freshly-allocated ``SlimTensor*`` wrapping the live + // allocation (dtype/sizes from the catalog; strides computed + // as contiguous). 
+ // + // AOTI's ``RAIIAtenTensorHandle`` deletes the returned ``SlimTensor`` + // on scope exit; do NOT cache the returned handle across calls + // (caching causes use-after-free when the next call enters with a + // stale handle pointing at freed memory). + ::executorch::runtime::Result<::executorch::backends::aoti::slim::SlimTensor*> + serve(::executorch::backends::aoti::slim::SlimTensor* input); + + // Diagnostics. + SessionStats stats() const; + void reset_stats(); + + private: + struct Impl; + std::unique_ptr impl_; + explicit Session(std::unique_ptr); +}; + +// --------------------------------------------------------------------------- +// ProbeRegistry — process-global dispatch table +// --------------------------------------------------------------------------- +// The only truly global state. Forced by the C ABI of +// ``aoti_torch_cuda_probe``: the probe c-shim only receives the input +// tensor handle, so it has to look up the owning Session from the +// tensor's ``data_ptr()``. +// +// Kept deliberately tiny: just the dispatch table and the synchronization +// primitives. No budget, no pool, no policy. Sessions register/unregister +// their own dummy pointers during ``bind_placeholder_constants`` / their +// destructor. +// +// Pointer-reuse safety: a Session's destructor calls +// ``unregister_session(this)`` BEFORE the pool's ``cudaFreeAsync`` runs, +// so by the time the allocator can hand out a freed ``data_ptr()`` to a +// new allocation, no registry entry references it. ``lookup`` of an +// unknown pointer returns ``Error::NotFound`` (not a crash); the +// probe c-shim translates that to a hard failure with diagnostics, since +// it indicates a bug elsewhere (pass emitted a probe for an unbound +// weight, or two sessions collided). +class ProbeRegistry { + public: + static ProbeRegistry& instance(); + + // Bulk-register all dummies for one Session. Called once from + // ``Session::bind_placeholder_constants``. 
+ ::executorch::runtime::Error register_dummies( + Session* session, + const std::vector>& dummy_to_fqn); + + // Bulk-unregister everything for one Session. Called once from + // ``Session::~Session``, BEFORE the pool is freed. + void unregister_session(Session* session); + + // Hot-path lookup. Returns the owning Session and the FQN string + // (string_view into the Session's storage; valid until the Session + // is destroyed). ``Error::NotFound`` if the pointer is unknown. + struct Lookup { + Session* session; + std::string_view fqn; + }; + ::executorch::runtime::Result lookup(void* dummy_data_ptr) const; + + private: + ProbeRegistry(); + ~ProbeRegistry(); + ProbeRegistry(const ProbeRegistry&) = delete; + ProbeRegistry& operator=(const ProbeRegistry&) = delete; + + struct Impl; + std::unique_ptr impl_; +}; + +} // namespace weight_offload +} // namespace cuda +} // namespace backends +} // namespace executorch