-
Notifications
You must be signed in to change notification settings - Fork 2.3k
[#12632][feat] Initial prototype for AutoDeploy compile cache #12698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -442,69 +442,62 @@ def _add_missing_load_hooks(gm: fx.GraphModule, model: nn.Module) -> None: | |
| The following module names were not found in exported module {list(post_hooks.keys())}""" | ||
|
|
||
|
|
||
| def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None: | ||
| """ | ||
| Add a load hook to handle aliased parameters in the model. | ||
|
|
||
| When parameters are aliased (multiple parameter names point to the same tensor), | ||
| we need to ensure all aliases get the same value during loading. This hook: | ||
| 1. Identifies groups of aliased parameters | ||
| 2. For each group, finds a valid parameter value from the state dict | ||
| 3. Applies that value to all aliases in the group | ||
| def _build_aliasing_load_pre_hook( | ||
| aliased_groups: List[List[str]], | ||
| ) -> Callable: | ||
| """Build a load hook that broadcasts aliased parameter values. | ||
|
|
||
| Args: | ||
| gm: The graph module to add the hook to | ||
| model: The source model containing the original parameter aliases | ||
| aliased_groups: Each group is a list of parameter names that alias the same | ||
| tensor. The hook ensures all names in a group see the same value during | ||
| ``load_state_dict``. | ||
|
|
||
| Returns: | ||
| A callable suitable for ``_register_load_state_dict_pre_hook``. | ||
| """ | ||
|
|
||
| def find_valid_param_value( | ||
| def _find_valid_param_value( | ||
| state_dict: Dict[str, torch.Tensor], param_names: List[str] | ||
| ) -> Optional[torch.Tensor]: | ||
| """Find a valid parameter value from state dict for a group of aliased parameters. | ||
|
|
||
| Args: | ||
| state_dict: The state dict being loaded | ||
| param_names: List of parameter names that are aliases of each other | ||
|
|
||
| Returns: | ||
| A valid tensor value if found, None otherwise | ||
| """ | ||
| # First try to find a non-meta tensor value | ||
| value = None | ||
| for name in param_names: | ||
| if name in state_dict: | ||
| value = state_dict[name] | ||
| if value.device.type != "meta": | ||
| return value | ||
|
|
||
| return value | ||
|
|
||
| def aliasing_load_pre_hook(state_dict: Dict[str, torch.Tensor], prefix: str, *args, **kwargs): | ||
| """Load hook that ensures aliased parameters get the same value.""" | ||
| for group in aliased_groups: | ||
| # Find a valid value for this group of aliases | ||
| value = find_valid_param_value(state_dict, group) | ||
|
|
||
| value = _find_valid_param_value(state_dict, group) | ||
| if value is not None: | ||
| # Apply the value to all aliases | ||
| for name in group: | ||
| state_dict[name] = value | ||
|
|
||
| ad_logger.debug(f"Applied value from {group[0]} to aliased parameters: {group}") | ||
|
|
||
| # Find all parameter aliases in the source model | ||
| param_to_names = defaultdict(list) | ||
| return aliasing_load_pre_hook | ||
|
|
||
|
|
||
| def _add_load_hook_for_aliased_params(gm: fx.GraphModule, model: nn.Module) -> None: | ||
| """Add a load hook to handle aliased parameters in the model. | ||
|
|
||
| When parameters are aliased (multiple parameter names point to the same tensor), | ||
| we need to ensure all aliases get the same value during loading. | ||
|
|
||
| Args: | ||
| gm: The graph module to add the hook to | ||
| model: The source model containing the original parameter aliases | ||
| """ | ||
| param_to_names: Dict[int, List[str]] = defaultdict(list) | ||
| for name, param in model.named_parameters(remove_duplicate=False): | ||
| param_to_names[id(param)].append(name) | ||
|
|
||
| # Filter to only groups with multiple aliases | ||
| aliased_groups = [names for names in param_to_names.values() if len(names) > 1] | ||
|
|
||
| if not aliased_groups: | ||
| return | ||
|
|
||
| # Register the hook | ||
| gm._register_load_state_dict_pre_hook(aliasing_load_pre_hook) | ||
| gm._register_load_state_dict_pre_hook(_build_aliasing_load_pre_hook(aliased_groups)) | ||
|
|
||
|
|
||
| def _rename_nodes_with_module_hierarchy(gm: fx.GraphModule) -> None: | ||
|
|
@@ -583,6 +576,28 @@ def _clean_up_assertions_and_guards(gm: fx.GraphModule): | |
| canonicalize_graph(gm) | ||
|
|
||
|
|
||
| def _remove_export_input_constraint_hooks(gm: fx.GraphModule) -> None: | ||
| """Remove ``_check_input_constraints_pre_hook`` added by ``torch.export``. | ||
|
|
||
| ``ep.module()`` attaches a forward pre-hook that validates inputs against | ||
| static shape constraints from the export call. The AutoDeploy pipeline | ||
| manages input shapes dynamically, so these hooks must be stripped to avoid | ||
| spurious ``RuntimeError`` during transforms like ``resize_kv_cache``. | ||
| """ | ||
| hooks_to_remove = [] | ||
| for handle_id, hook in gm._forward_pre_hooks.items(): | ||
| fn = hook if not hasattr(hook, "__func__") else hook.__func__ | ||
| name = getattr(fn, "__name__", "") or getattr(fn, "__qualname__", "") | ||
| if "check_input_constraints" in name: | ||
| hooks_to_remove.append(handle_id) | ||
|
|
||
| for handle_id in hooks_to_remove: | ||
| del gm._forward_pre_hooks[handle_id] | ||
|
|
||
| if hooks_to_remove: | ||
| ad_logger.debug(f"Removed {len(hooks_to_remove)} export input constraint hook(s)") | ||
|
|
||
|
|
||
| def run_forward_for_capture( | ||
| model: nn.Module, | ||
| capture_fn: Optional[Callable[..., nn.Module]] = None, | ||
|
|
@@ -723,6 +738,11 @@ def _capture_fn(model, args, kwargs): | |
| # clean up checks --> generally the sanity checks are overly conservative and we can remove them | ||
| _clean_up_assertions_and_guards(egm) | ||
|
|
||
| # Remove input constraint hooks added by torch.export — the AutoDeploy pipeline | ||
| # manages input shapes dynamically and these hooks would reject valid inputs | ||
| # during resize_kv_cache and other forward passes with different batch sizes. | ||
| _remove_export_input_constraint_hooks(egm) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's great. It's a good time to just get rid of the input constraints altogether |
||
|
|
||
| # Rename nodes to reflect module hierarchy for better debuggability | ||
| _rename_nodes_with_module_hierarchy(egm) | ||
|
|
||
|
|
@@ -780,6 +800,7 @@ def export_onnx(ad_config: "LlmArgs") -> nn.Module: | |
| inference_optimizer = InferenceOptimizer( | ||
| factory=factory, | ||
| config=ad_config.transforms, | ||
| pipeline_cache_config=ad_config.pipeline_cache, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. See my other comment: ideally, we don't need this. |
||
| ) | ||
|
|
||
| # 4. Run the transform pipeline (includes export_to_onnx transform) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| from pydantic import Field, ValidationInfo, field_validator, model_validator | ||
| from pydantic_settings import BaseSettings, SettingsConfigDict | ||
|
|
||
| from tensorrt_llm.llmapi.utils import StrictBaseModel | ||
| from tensorrt_llm.mapping import Mapping | ||
|
|
||
| from ...llmapi.llm_args import ( | ||
|
|
@@ -16,6 +17,7 @@ | |
| _ParallelConfig, | ||
| ) | ||
| from .models import ModelFactory, ModelFactoryRegistry | ||
| from .transform.interface import Stages | ||
| from .utils._config import DynamicYamlMixInForSettings | ||
| from .utils.logger import ad_logger | ||
|
|
||
|
|
@@ -55,6 +57,24 @@ def _shortcut_description(description: str, shortcut: str) -> str: | |
| return f"{description} Alias for: {long_names_str}." | ||
|
|
||
|
|
||
| class PipelineCacheConfig(StrictBaseModel): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. See my other comments: could just be another transform. |
||
| """Configuration for the portable AutoDeploy pipeline snapshot cache.""" | ||
|
|
||
| enabled: bool = Field( | ||
| default=False, | ||
| description="Whether to enable pipeline snapshot caching for AutoDeploy.", | ||
| ) | ||
| root: Path = Field( | ||
| default_factory=lambda: Path.home() / ".cache" / "tensorrt_llm" / "auto_deploy_pipeline", | ||
| description="Root directory used to store AutoDeploy pipeline snapshots.", | ||
| ) | ||
| boundary: str = Field( | ||
| default="sharding_transform_executor", | ||
| description="Pipeline boundary transform name used for snapshot save/restore. The " | ||
| "boundary must be at or before the sharding stage (pre-weight-loading).", | ||
| ) | ||
|
|
||
|
|
||
| class LlmArgs(DynamicYamlMixInForSettings, TorchLlmArgs, BaseSettings): | ||
| """LlmArgs config class for providing full expert configurability of the AutoDeploy backend.""" | ||
|
|
||
|
|
@@ -201,8 +221,8 @@ def validate_and_init_tokenizer(self): | |
| default_factory=dict, | ||
| description="Extra kwargs for the tokenizer class to customize the tokenizer. Same as " | ||
| "model_kwargs. For example, the default HF Llama tokenizer can be initialized with the " | ||
| "arguments specified here: " | ||
| "https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama_fast.py#L127.", | ||
| "arguments specified here: https://github.com/huggingface/transformers/blob/main/src/" | ||
| "transformers/models/llama/tokenization_llama_fast.py#L127.", | ||
| ) | ||
|
|
||
| ### RUNTIME FEATURES ########################################################################### | ||
|
|
@@ -240,6 +260,10 @@ def validate_and_init_tokenizer(self): | |
| description="A dictionary of transform configurations. The key is the transform name and " | ||
| "the value is the transform configuration.", | ||
| ) | ||
| pipeline_cache: PipelineCacheConfig = Field( | ||
| default_factory=PipelineCacheConfig, | ||
| description="Configuration for the AutoDeploy pipeline snapshot cache.", | ||
| ) | ||
|
|
||
| ### SHORTCUTS FOR COMMON INFERENCE OPTIMIZER CONFIGS ########################################### | ||
| compile_backend: str = Field( | ||
|
|
@@ -350,6 +374,26 @@ def cap_max_batch_size_to_max_num_tokens(self): | |
| self.max_batch_size = self.max_num_tokens | ||
| return self | ||
|
|
||
| @model_validator(mode="after") | ||
| def validate_pipeline_cache(self): | ||
| if not self.pipeline_cache.enabled: | ||
| return self | ||
|
|
||
| boundary_name = self.pipeline_cache.boundary | ||
| if boundary_name not in self.transforms: | ||
| raise ValueError( | ||
| f"Pipeline cache boundary '{boundary_name}' is not present in transforms." | ||
| ) | ||
|
|
||
| boundary_stage = Stages(self.transforms[boundary_name]["stage"]) | ||
| if boundary_stage > Stages.SHARDING: | ||
| raise ValueError( | ||
| "The pipeline cache only supports pre-weight-loading boundaries through " | ||
| f"sharding. Got '{boundary_name}' at stage '{boundary_stage.value}'." | ||
| ) | ||
|
|
||
| return self | ||
|
|
||
| ### UTILITY METHODS ############################################################################ | ||
| def create_factory(self) -> ModelFactory: | ||
| """Create a model factory from the arguments. | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to at least also add |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would love for this to just be another transform. The way I think about it is you can insert it at the point where you check for the cache, and then we have a shared config in
interface.py. You can use that shared config and modify that transform/interface.py file to then check that shared config and just skip over transforms if there was a cache hit. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good idea, will update.