Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# This is the set of transforms running in "graph" mode. In this mode, we capture the full graph
# of the model and optimize it for inference.
pipeline_cache:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would love for this to just be another transform. The way I think about it is you can insert it at the point where you check for the cache, and then we have a shared config in interface.py. You can use that shared config and modify that transform/interface.py file to then check that shared config and just skip over transforms if there was a cache hit. What do you think?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea, will update.

enabled: true
boundary: sharding_transform_executor

transforms:
############################################################################################
# BUILD MODEL, EXPORT TO GRAPH MODULE, AND CLEAN UP
Expand Down
89 changes: 55 additions & 34 deletions tensorrt_llm/_torch/auto_deploy/export/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,69 +442,62 @@ def _add_missing_load_hooks(gm: fx.GraphModule, model: nn.Module) -> None:
The following module names were not found in exported module {list(post_hooks.keys())}"""


def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None:
"""
Add a load hook to handle aliased parameters in the model.

When parameters are aliased (multiple parameter names point to the same tensor),
we need to ensure all aliases get the same value during loading. This hook:
1. Identifies groups of aliased parameters
2. For each group, finds a valid parameter value from the state dict
3. Applies that value to all aliases in the group
def _build_aliasing_load_pre_hook(
aliased_groups: List[List[str]],
) -> Callable:
"""Build a load hook that broadcasts aliased parameter values.

Args:
gm: The graph module to add the hook to
model: The source model containing the original parameter aliases
aliased_groups: Each group is a list of parameter names that alias the same
tensor. The hook ensures all names in a group see the same value during
``load_state_dict``.

Returns:
A callable suitable for ``_register_load_state_dict_pre_hook``.
"""

def find_valid_param_value(
def _find_valid_param_value(
state_dict: Dict[str, torch.Tensor], param_names: List[str]
) -> Optional[torch.Tensor]:
"""Find a valid parameter value from state dict for a group of aliased parameters.

Args:
state_dict: The state dict being loaded
param_names: List of parameter names that are aliases of each other

Returns:
A valid tensor value if found, None otherwise
"""
# First try to find a non-meta tensor value
value = None
for name in param_names:
if name in state_dict:
value = state_dict[name]
if value.device.type != "meta":
return value

return value

def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *args, **kwargs):
"""Load hook that ensures aliased parameters get the same value."""
for group in aliased_groups:
# Find a valid value for this group of aliases
value = find_valid_param_value(state_dict, group)

value = _find_valid_param_value(state_dict, group)
if value is not None:
# Apply the value to all aliases
for name in group:
state_dict[name] = value

ad_logger.debug(f"Applied value from {group[0]} to aliased parameters: {group}")

# Find all parameter aliases in the source model
param_to_names = defaultdict(list)
return aliasing_load_pre_hook


def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None:
"""Add a load hook to handle aliased parameters in the model.

When parameters are aliased (multiple parameter names point to the same tensor),
we need to ensure all aliases get the same value during loading.

Args:
gm: The graph module to add the hook to
model: The source model containing the original parameter aliases
"""
param_to_names: Dict[int, List[str]] = defaultdict(list)
for name, param in model.named_parameters(remove_duplicate=False):
param_to_names[id(param)].append(name)

# Filter to only groups with multiple aliases
aliased_groups = [names for names in param_to_names.values() if len(names) > 1]

if not aliased_groups:
return

# Register the hook
gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook)
gm._register_load_state_dict_pre_hook(_build_aliasing_load_pre_hook(aliased_groups))


def _rename_nodes_with_module_hierarchy(gm: fx.GraphModule) -> None:
Expand Down Expand Up @@ -583,6 +576,28 @@ def _clean_up_assertions_and_guards(gm: fx.GraphModule):
canonicalize_graph(gm)


def _remove_export_input_constraint_hooks(gm: fx.GraphModule) -> None:
"""Remove ``_check_input_constraints_pre_hook`` added by ``torch.export``.

``ep.module()`` attaches a forward pre-hook that validates inputs against
static shape constraints from the export call. The AutoDeploy pipeline
manages input shapes dynamically, so these hooks must be stripped to avoid
spurious ``RuntimeError`` during transforms like ``resize_kv_cache``.
"""
hooks_to_remove = []
for handle_id, hook in gm._forward_pre_hooks.items():
fn = hook if not hasattr(hook, "__func__") else hook.__func__
name = getattr(fn, "__name__", "") or getattr(fn, "__qualname__", "")
if "check_input_constraints" in name:
hooks_to_remove.append(handle_id)

for handle_id in hooks_to_remove:
del gm._forward_pre_hooks[handle_id]

if hooks_to_remove:
ad_logger.debug(f"Removed {len(hooks_to_remove)} export input constraint hook(s)")


def run_forward_for_capture(
model: nn.Module,
capture_fn: Optional[Callable[..., nn.Module]] = None,
Expand Down Expand Up @@ -723,6 +738,11 @@ def _capture_fn(model, args, kwargs):
# clean up checks --> generally the sanity checks are overly conservative and we can remove them
_clean_up_assertions_and_guards(egm)

# Remove input constraint hooks added by torch.export — the AutoDeploy pipeline
# manages input shapes dynamically and these hooks would reject valid inputs
# during resize_kv_cache and other forward passes with different batch sizes.
_remove_export_input_constraint_hooks(egm)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great. It's a good time to just get rid of the input constraints altogether.


# Rename nodes to reflect module hierarchy for better debuggability
_rename_nodes_with_module_hierarchy(egm)

Expand Down Expand Up @@ -780,6 +800,7 @@ def export_onnx(ad_config: "LlmArgs") -> nn.Module:
inference_optimizer = InferenceOptimizer(
factory=factory,
config=ad_config.transforms,
pipeline_cache_config=ad_config.pipeline_cache,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my other comment: ideally, we don't need this.

)

# 4. Run the transform pipeline (includes export_to_onnx transform)
Expand Down
48 changes: 46 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import Field, ValidationInfo, field_validator, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from tensorrt_llm.llmapi.utils import StrictBaseModel
from tensorrt_llm.mapping import Mapping

from ...llmapi.llm_args import (
Expand All @@ -16,6 +17,7 @@
_ParallelConfig,
)
from .models import ModelFactory, ModelFactoryRegistry
from .transform.interface import Stages
from .utils._config import DynamicYamlMixInForSettings
from .utils.logger import ad_logger

Expand Down Expand Up @@ -55,6 +57,24 @@ def _shortcut_description(description: str, shortcut: str) -> str:
return f"{description} Alias for: {long_names_str}."


class PipelineCacheConfig(StrictBaseModel):
    """Configuration for the portable AutoDeploy pipeline snapshot cache."""

    # Caching is opt-in; when disabled the pipeline always runs end-to-end.
    enabled: bool = Field(
        default=False,
        description="Whether to enable pipeline snapshot caching for AutoDeploy.",
    )
    # Snapshots live under the user cache directory unless overridden.
    root: Path = Field(
        default_factory=lambda: Path.home() / ".cache" / "tensorrt_llm" / "auto_deploy_pipeline",
        description="Root directory used to store AutoDeploy pipeline snapshots.",
    )
    # Name of the transform at which snapshots are saved/restored.
    boundary: str = Field(
        default="sharding_transform_executor",
        description="Pipeline boundary transform name used for snapshot save/restore. The "
        "boundary must be at or before the sharding stage (pre-weight-loading).",
    )


class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings):
"""LlmArgs config class for providing full expert configurability of the AutoDeploy backend."""

Expand Down Expand Up @@ -201,8 +221,8 @@ def validate_and_init_tokenizer(self):
default_factory=dict,
description="Extra kwargs for the tokenizer class to customize the tokenizer. Same as "
"model_kwargs. For example, the default HF Llama tokenizer can be initialized with the "
"arguments specified here: "
"https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L127.",
"arguments specified here: https://github.com/huggingface/transformers/blob/main/src/"
"transformers/models/llama/tokenization_llama_fast.py#L127.",
)

### RUNTIME FEATURES ###########################################################################
Expand Down Expand Up @@ -240,6 +260,10 @@ def validate_and_init_tokenizer(self):
description="A dictionary of transform configurations. The key is the transform name and "
"the value is the transform configuration.",
)
pipeline_cache: PipelineCacheConfig = Field(
default_factory=PipelineCacheConfig,
description="Configuration for the AutoDeploy pipeline snapshot cache.",
)

### SHORTCUTS FOR COMMON INFERENCE OPTIMIZER CONFIGS ###########################################
compile_backend: str = Field(
Expand Down Expand Up @@ -350,6 +374,26 @@ def cap_max_batch_size_to_max_num_tokens(self):
self.max_batch_size = self.max_num_tokens
return self

@model_validator(mode="after")
def validate_pipeline_cache(self):
    """Validate pipeline-cache settings against the configured transforms.

    Raises:
        ValueError: If the configured boundary transform does not exist, or if
            it sits after the sharding stage (the cache only supports
            pre-weight-loading boundaries).
    """
    # Nothing to validate when caching is switched off.
    if not self.pipeline_cache.enabled:
        return self

    boundary = self.pipeline_cache.boundary
    if boundary not in self.transforms:
        raise ValueError(f"Pipeline cache boundary '{boundary}' is not present in transforms.")

    # NOTE(review): assumes transform configs are dicts carrying a "stage" key — confirm.
    stage = Stages(self.transforms[boundary]["stage"])
    if stage > Stages.SHARDING:
        raise ValueError(
            "The pipeline cache only supports pre-weight-loading boundaries through "
            f"sharding. Got '{boundary}' at stage '{stage.value}'."
        )

    return self

### UTILITY METHODS ############################################################################
def create_factory(self) -> ModelFactory:
"""Create a model factory from the arguments.
Expand Down
12 changes: 12 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to at least also add model_kwargs, right?

Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,18 @@ def vocab_size_padded(self) -> Optional[int]:
"""
return None

def get_pipeline_cache_model_identifier(self) -> str:
    """Return a stable model identifier for pipeline cache key generation."""
    # The configured model reference (repo id or local path) is stable across runs.
    model_ref = self._model
    return str(model_ref)

def get_pipeline_cache_checkpoint_fingerprint(self) -> str:
    """Return a checkpoint fingerprint input for pipeline cache key generation.

    The default implementation uses the configured model identifier/path directly. Model
    factories may override this with a stronger fingerprint if they can determine one cheaply.
    """
    # Base-class fallback: the path string itself serves as the fingerprint.
    model_ref = self._model
    return str(model_ref)

def build_model(self, device: str) -> nn.Module:
"""Build the model on the desired device.

Expand Down
108 changes: 108 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/models/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Interface to initialize and load HF models."""

import hashlib
import json
import math
import operator
Expand Down Expand Up @@ -53,6 +54,80 @@
from .quant_config_reader import QuantConfigReader, autodetect_quant_config_reader


def _extract_hf_snapshot_sha(checkpoint_dir: str) -> Optional[str]:
"""Return the HF commit SHA when *checkpoint_dir* is an HF snapshot path.

``huggingface_hub.snapshot_download`` resolves to a content-addressed path
of the form
``.../models--<org>--<repo>/snapshots/<commit_sha>/``. When the path
matches, the commit SHA is a strong fingerprint on its own — no need to
walk the directory.
"""
parts = os.path.realpath(checkpoint_dir).split(os.sep)
try:
snapshots_idx = parts.index("snapshots")
except ValueError:
return None
if snapshots_idx + 1 >= len(parts):
return None
sha = parts[snapshots_idx + 1]
# HF commit SHAs are 40-char lowercase hex. Require at least 7 hex chars
# (git short-SHA) to avoid false positives on paths that happen to have a
# ``snapshots`` directory.
if len(sha) >= 7 and all(c in "0123456789abcdef" for c in sha.lower()):
return sha
return None


def _hash_checkpoint_metadata(checkpoint_dir: str) -> str:
"""Return a content-based SHA256 fingerprint of a checkpoint directory.

Fast path: an HF snapshot directory (``.../snapshots/<sha>/``) returns the
commit SHA directly — HF's content-addressed storage already guarantees it
uniquely identifies the snapshot.

Otherwise walk the directory in sorted order. Weight shards are
fingerprinted by relative-path + file size (reading multi-GB tensors on the
hot path would defeat the cache); every other file contributes its full
content. This handles any future metadata file without an allowlist.
"""
sha = _extract_hf_snapshot_sha(checkpoint_dir)
if sha is not None:
return f"hf_snapshot:{sha}"

weight_suffixes = (".safetensors", ".bin", ".pt", ".pth", ".gguf")
h = hashlib.sha256()
ckpt_path = os.path.realpath(checkpoint_dir)

for root, dirs, files in os.walk(ckpt_path):
dirs.sort()
files.sort()
rel_root = os.path.relpath(root, ckpt_path)
for fname in files:
fpath = os.path.join(root, fname)
rel_path = fname if rel_root == "." else os.path.join(rel_root, fname)
rel_path = rel_path.replace(os.sep, "/")

if fname.endswith(weight_suffixes):
try:
size = os.path.getsize(fpath)
except OSError:
h.update(f"shard-missing:{rel_path}\n".encode("utf-8"))
continue
h.update(f"shard:{rel_path}:{size}\n".encode("utf-8"))
continue

h.update(f"file:{rel_path}\n".encode("utf-8"))
try:
with open(fpath, "rb") as f:
for chunk in iter(lambda: f.read(1 << 20), b""):
h.update(chunk)
except OSError:
h.update(b"unreadable\n")

return h.hexdigest()


@contextmanager
def hf_load_state_dict_with_device(device: DeviceLikeType):
"""Patch HF loading utilities according to our needs.
Expand Down Expand Up @@ -310,6 +385,39 @@ def _set_sharding_config(self, model_config: PretrainedConfig):
if hasattr(model_config, "base_model_tp_plan"):
self._sharding_config["tp_plan"] = model_config.base_model_tp_plan

def get_pipeline_cache_model_identifier(self) -> str:
    """Stable model identifier for the pipeline cache.

    Uses the repo id / path as configured, not the resolved snapshot
    directory. The content fingerprint
    (``get_pipeline_cache_checkpoint_fingerprint``) handles the case where
    the same path points to different checkpoint contents over time.
    """
    configured = self._model
    return str(configured)

def get_pipeline_cache_checkpoint_fingerprint(self) -> str:
    """Content-based fingerprint over checkpoint metadata.

    Hashes files that determine the graph structure or the tensor layout
    loaded from disk:

    - ``config.json``, ``generation_config.json`` (if present)
    - ``*.safetensors.index.json`` / ``pytorch_model.bin.index.json``
      (if present)
    - each weight shard's name + file size (not its contents — reading
      multi-GB weights on the hot path defeats the cache)

    Falls back to the configured path string if the resolved checkpoint
    directory is not yet available (e.g. before ``prefetch_checkpoint``).
    """
    try:
        ckpt_dir = self.model
        if ckpt_dir and os.path.isdir(ckpt_dir):
            return _hash_checkpoint_metadata(ckpt_dir)
        # Checkpoint not resolved/downloaded yet — fall back to the path string.
        return str(self._model)
    except Exception as exc:  # noqa: BLE001
        ad_logger.debug(f"Pipeline cache: falling back to path-based fingerprint ({exc})")
        return str(self._model)

def get_quant_config(self) -> Dict:
"""Returns the quantization config for this model or an empty dict if not quantized."""
if self._quant_config_reader is not None:
Expand Down
Loading
Loading