[#12632][feat] Initial prototype for AutoDeploy compile cache #12698

nvchenghaoz wants to merge 1 commit into NVIDIA:main from
Conversation
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
📝 Walkthrough

A pre-weight-loading pipeline cache feature is added to TensorRT-LLM's AutoDeploy, enabling snapshot save/restore of FX graphs at configured transform boundaries (defaulting to after sharding). Key additions include an AD IR serialization format for graph persistence, a

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant Optimizer as InferenceOptimizer
    participant CacheMgr as PipelineSnapshotManager
    participant SharedCfg as SharedConfig
    participant Transform as BaseTransform
    participant Module as nn.Module (GraphModule)
    participant IR as AD IR Handler
    User->>Optimizer: __call__(module, config)
    activate Optimizer
    Optimizer->>CacheMgr: maybe_restore(module)
    activate CacheMgr
    CacheMgr->>CacheMgr: Select valid boundary, validate manifest
    alt Cache Hit
        CacheMgr->>IR: load_ir(rank_dir)
        activate IR
        IR-->>IR: Deserialize graph from JSON
        IR-->>IR: Reconstruct GraphModule
        IR-->>IR: Hydrate shapes via FakeTensorProp
        IR-->>CacheMgr: Return module, real_buffers
        deactivate IR
        CacheMgr->>CacheMgr: Reattach hooks from specs
        CacheMgr->>CacheMgr: Replay source-model hooks
        CacheMgr-->>Optimizer: (restored_module, start_idx)
    else Cache Miss
        CacheMgr-->>Optimizer: (None, 0)
    end
    deactivate CacheMgr
    loop For each transform [start_idx → end]
        Optimizer->>Transform: __call__(module)
        activate Transform
        Transform->>Module: Execute transform logic
        activate Module
        Module-->>Transform: Transformed module
        deactivate Module
        Transform->>SharedCfg: Update autodeploy_meta
        Transform->>CacheMgr: maybe_save(transform_name, idx, config, module)
        activate CacheMgr
        alt At Cache Boundary
            CacheMgr->>IR: extract_ir(module, hook_specs)
            activate IR
            IR->>IR: Walk FX graph, serialize targets/args
            IR->>IR: Capture hook specs from load hooks
            IR-->>CacheMgr: (IRGraph, real_buffers_dict)
            deactivate IR
            CacheMgr->>CacheMgr: Synchronize ranks (distributed barrier)
            CacheMgr->>CacheMgr: Write manifest.json + ad_ir.json
            CacheMgr->>CacheMgr: Write real_buffers.pt
        else Not at boundary
            CacheMgr->>CacheMgr: Skip (no-op)
        end
        deactivate CacheMgr
        deactivate Transform
    end
    Optimizer-->>User: Optimized module
    deactivate Optimizer
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 3
🧹 Nitpick comments (5)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (1)

258-282: Consider defensive handling for deserialization edge cases.

The deserialization logic assumes the stored dictionary structure matches `TransformInfo` fields exactly. If the cache format evolves, this could raise `TypeError` on field mismatch. Consider wrapping the `TransformInfo(**value)` call in a try-except that falls back gracefully or logs a warning when encountering incompatible cached data.

♻️ Optional: Add defensive handling for format changes

```diff
 def deserialize_autodeploy_meta(autodeploy_meta: Dict[str, Any]) -> AutodeployMeta:
     """Restore typed AutoDeploy metadata from a serialized dictionary."""
     restored = dict(autodeploy_meta)
     history = restored.get(BaseTransform._history_key, {})
-    restored[BaseTransform._history_key] = {
-        name: value if isinstance(value, TransformInfo) else TransformInfo(**value)
-        for name, value in history.items()
-    }
+    restored_history = {}
+    for name, value in history.items():
+        if isinstance(value, TransformInfo):
+            restored_history[name] = value
+        else:
+            try:
+                restored_history[name] = TransformInfo(**value)
+            except TypeError:
+                # Fallback for incompatible cached format
+                restored_history[name] = TransformInfo()
+    restored[BaseTransform._history_key] = restored_history
     return restored
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/auto_deploy/transform/interface.py` around lines 258 - 282: the deserializer in deserialize_autodeploy_meta assumes history entries can be directly instantiated via TransformInfo(**value); wrap that instantiation in a try/except to defensively handle incompatible formats: on exception log a warning (including the entry name and error) and fall back to either keeping the raw dict value or creating a minimal TransformInfo placeholder so the rest of restored data remains usable; ensure you reference BaseTransform._history_key and preserve non-history keys when reconstructing restored.

tensorrt_llm/_torch/auto_deploy/transform/ad_ir.py (2)
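The fallback pattern this comment asks for can be sketched standalone. Note the `TransformInfo` below is a hypothetical stand-in with placeholder fields (`skipped`, `num_matches`), not the real class in `interface.py`:

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class TransformInfo:
    # Placeholder fields; the real TransformInfo schema lives in interface.py.
    skipped: bool = False
    num_matches: int = 0


def restore_history(history: Dict[str, Any]) -> Dict[str, TransformInfo]:
    """Rebuild typed entries, degrading gracefully on incompatible cached dicts."""
    restored: Dict[str, TransformInfo] = {}
    for name, value in history.items():
        if isinstance(value, TransformInfo):
            restored[name] = value
            continue
        try:
            restored[name] = TransformInfo(**value)
        except TypeError:
            # Cached format drifted: keep a default placeholder instead of failing.
            restored[name] = TransformInfo()
    return restored
```

A stale cache entry with unknown keys then yields a default `TransformInfo` rather than aborting the whole restore.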
984-991: Add `map_location` for cross-device compatibility.

When loading real_buffers.pt, consider specifying `map_location` to handle cases where the cache was created on a different device (e.g., a different GPU).

♻️ Proposed fix

```diff
 if buf_path.exists():
     try:
-        real_buffers = torch.load(buf_path, weights_only=True)
+        real_buffers = torch.load(buf_path, weights_only=True, map_location="cpu")
     except TypeError:
-        real_buffers = torch.load(buf_path)
+        real_buffers = torch.load(buf_path, map_location="cpu")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/ad_ir.py` around lines 984 - 991, The code loading real_buffers from buf_path should pass a map_location to torch.load for cross-device compatibility; update both torch.load calls (the one with weights_only=True inside the try and the fallback in except) to include map_location="cpu" (or torch.device("cpu")) so cached tensors created on a different GPU/ CUDA device are remapped to CPU when loaded, preserving the existing TypeError fallback behavior and variable real_buffers.
175-206: Use `raise ... from` for exception chaining.

Per Ruff B904, exceptions raised within `except` blocks should use `raise ... from err` or `raise ... from None` to preserve the exception chain.

♻️ Proposed fix

```diff
 if hasattr(parent, leaf):
     return getattr(parent, leaf)
-raise ModuleNotFoundError(f"Cannot resolve target: {key}")
+raise ModuleNotFoundError(f"Cannot resolve target: {key}") from None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `tensorrt_llm/_torch/auto_deploy/transform/ad_ir.py` around lines 175 - 206: the code currently raises ModuleNotFoundError inside an except block without chaining the original import error; change the outer exception handler to capture the original exception (e.g., except (ModuleNotFoundError, ImportError) as err:) and when re-raising use "raise ModuleNotFoundError(f'Cannot resolve target: {key}') from err" so the original importlib.import_module error is preserved; ensure any other explicit raises in this block also use "from err" where appropriate (referencing key, module_path, importlib.import_module, and ModuleNotFoundError).

tensorrt_llm/_torch/auto_deploy/transform/pipeline_cache.py (2)
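In isolation, the chained-raise pattern looks like this; `resolve_target` is a simplified stand-in for the real resolver, not the actual `ad_ir.py` code:

```python
import importlib


def resolve_target(key: str):
    """Resolve a dotted 'module.attr' name, chaining the original import failure."""
    module_path, _, leaf = key.rpartition(".")
    try:
        module = importlib.import_module(module_path)
    except ImportError as err:
        # `from err` stores the original exception in __cause__, so the
        # real import failure is still visible in the traceback.
        raise ModuleNotFoundError(f"Cannot resolve target: {key}") from err
    return getattr(module, leaf)
```

With `from err`, a caller inspecting the `ModuleNotFoundError` can still see why the underlying import failed; `from None` would instead suppress the chain deliberately.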
1035-1053: Validate `active_args` against unknown names before assignment.

The `_restore_sidecars` method validates that each `arg_name` exists in `cm.info.available_args` but raises a `ValueError` only after iterating. Consider failing fast on the first unknown arg.

Also, assigning to private attributes (`_active_args`, `_active_host_prep_args`, `_use_flattened_layout`) creates tight coupling with `SequenceInfo` internals. Consider adding public setter methods or documenting this coupling.
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/pipeline_cache.py` around lines 1035 - 1053, The loop in _restore_sidecars currently collects active_args but only raises ValueError after iteration; change it to fail fast by checking each arg_name against cm.info.available_args and immediately raising ValueError on the first unknown name to avoid partial state changes; also stop writing directly to private attributes cm.info._active_args, cm.info._active_host_prep_args, and cm.info._use_flattened_layout — instead add/use public setter methods on the SequenceInfo object (e.g., set_active_args, set_active_host_prep_args, set_use_flattened_layout) or, if setters are not available, document this tight coupling and wrap the assignments in a single helper method on cm.info to centralize and make the intent explicit.
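The fail-fast shape the prompt describes can be sketched with generic names (this is not the actual `_restore_sidecars` signature):

```python
from typing import Iterable, List, Set


def validate_active_args(requested: Iterable[str], available: Set[str]) -> List[str]:
    """Raise on the first unknown arg so no partial state is ever committed."""
    active: List[str] = []
    for name in requested:
        if name not in available:
            # Failing here, before any assignment, avoids a half-restored state.
            raise ValueError(f"Unknown arg: {name!r} (available: {sorted(available)})")
        active.append(name)
    return active
```

Only after this returns would the caller commit the validated list in one step (ideally through a public setter rather than a private attribute).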
434-447: Add `strict=True` to the `zip()` call.

Per static analysis (B905), `zip()` should have an explicit `strict=` parameter to catch length mismatches between `co_freevars` and `closure`.

♻️ Proposed fix

```diff
-for name, cell in zip(code.co_freevars, closure):
+for name, cell in zip(code.co_freevars, closure, strict=True):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/pipeline_cache.py` around lines 434 - 447, The zip between code.co_freevars and closure in _extract_closure_vars can silently ignore length mismatches; update the zip call to use explicit strict=True (i.e., zip(code.co_freevars, closure, strict=True)) so any mismatch raises immediately, leaving the rest of the function logic intact and still handling ValueError from cell.cell_contents as before.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7c074f32-6936-4ef1-a3c0-805c22ff7add
📒 Files selected for processing (12)
examples/auto_deploy/cache_design.md
tensorrt_llm/_torch/auto_deploy/config/default.yaml
tensorrt_llm/_torch/auto_deploy/export/export.py
tensorrt_llm/_torch/auto_deploy/llm_args.py
tensorrt_llm/_torch/auto_deploy/models/factory.py
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
tensorrt_llm/_torch/auto_deploy/transform/ad_ir.py
tensorrt_llm/_torch/auto_deploy/transform/interface.py
tensorrt_llm/_torch/auto_deploy/transform/optimizer.py
tensorrt_llm/_torch/auto_deploy/transform/pipeline_cache.py
tests/unittest/auto_deploy/singlegpu/shim/test_llm_config.py
tests/unittest/auto_deploy/singlegpu/transformations/test_pipeline_cache.py
> ```
> FACTORY → EXPORT → POST_EXPORT → PATTERN_MATCHER → SHARDING
> → [CACHE BOUNDARY] →
> WEIGHT_LOAD → POST_LOAD_FUSION → CACHE_INIT → COMPILE
> ```
Add a language specifier to the fenced code block.

The pipeline stages diagram is missing a language specifier. Add `text` to improve rendering and satisfy linter requirements.
````diff
-```
+```text
 FACTORY → EXPORT → POST_EXPORT → PATTERN_MATCHER → SHARDING
 → [CACHE BOUNDARY] →
 WEIGHT_LOAD → POST_LOAD_FUSION → CACHE_INIT → COMPILE
 ```
````
🧰 Tools

🪛 markdownlint-cli2 (0.22.0)

[warning] 10-10: Fenced code blocks should have a language specified
(MD040, fenced-code-language)

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed.
In examples/auto_deploy/cache_design.md around lines 10 - 14: the fenced code block containing the pipeline diagram (lines showing "FACTORY → EXPORT → POST_EXPORT → PATTERN_MATCHER → SHARDING → [CACHE BOUNDARY] → WEIGHT_LOAD → POST_LOAD_FUSION → CACHE_INIT → COMPILE") should include a language specifier; update the opening fence to add the `text` specifier so the diagram is rendered as text and satisfies the linter.
```python
def _serialize_treespec(spec: Any) -> Optional[str]:
    if spec is None:
        return None
    try:
        buf = io.BytesIO()
        pickle.dump(spec, buf)
        return base64.b64encode(buf.getvalue()).decode("ascii")
    except Exception:  # noqa: BLE001
        return None


def _deserialize_treespec(data: Optional[str]) -> Any:
    if data is None:
        return None
    raw = base64.b64decode(data.encode("ascii"))
    return pickle.loads(raw)  # noqa: S301
```
Document security model for pickle usage.
Bandit flags pickle usage (B301/B403). While necessary for TreeSpec serialization (no JSON alternative exists), the security implications should be documented. Cache files from untrusted sources could execute arbitrary code.
Consider adding a comment documenting the trust model:
📝 Proposed documentation

```diff
 def _deserialize_treespec(data: Optional[str]) -> Any:
+    # SECURITY: TreeSpec deserialization uses pickle. Cache files must come from
+    # trusted sources (same machine/user). Do not load pipeline cache from
+    # untrusted network locations.
     if data is None:
         return None
     raw = base64.b64decode(data.encode("ascii"))
     return pickle.loads(raw)  # noqa: S301
```

🧰 Tools
🪛 GitHub Actions: Release Checks
[error] 340-340: Bandit (B301:blacklist): Pickle and modules that wrap it can be unsafe when used to deserialize untrusted data. Location: tensorrt_llm/_torch/auto_deploy/transform/ad_ir.py:340:11
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/transform/ad_ir.py` around lines 325 - 341,
Add a clear comment above the _serialize_treespec and _deserialize_treespec
functions documenting the security model for using pickle: state that pickle is
used intentionally for TreeSpec (no JSON alternative), explain that
untrusted/remote cache files may execute arbitrary code when deserialized, and
instruct that cached/serialized data must only be loaded from trusted sources or
validated beforehand; reference the functions _serialize_treespec and
_deserialize_treespec in the comment so reviewers can find the trust model
easily.
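The pickle-plus-base64 round trip at issue can be shown with a generic object; the helper names below are illustrative, not the module's actual API:

```python
import base64
import io
import pickle
from typing import Any, Optional


def serialize_obj(obj: Any) -> Optional[str]:
    """Pickle then base64-encode so the payload can be embedded in JSON."""
    if obj is None:
        return None
    buf = io.BytesIO()
    pickle.dump(obj, buf)
    return base64.b64encode(buf.getvalue()).decode("ascii")


def deserialize_obj(data: Optional[str]) -> Any:
    # SECURITY: pickle.loads executes arbitrary code embedded in the payload.
    # Only decode strings produced locally by serialize_obj / trusted caches.
    if data is None:
        return None
    return pickle.loads(base64.b64decode(data.encode("ascii")))
```

This mirrors why the review asks for a documented trust model: the base64 wrapper makes the payload JSON-safe, but it does nothing to make deserialization safe against untrusted input.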
```python
def _repo_git_sha() -> Optional[str]:
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=_repo_root(),
            capture_output=True,
            check=True,
            text=True,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    return proc.stdout.strip() or None
```
Address Bandit security warnings for subprocess usage.
The pipeline fails due to Bandit B607 (partial executable path). While this usage is safe (read-only git rev-parse), you can satisfy the security scanner by using an absolute path or documenting the security model.
Consider documenting the security rationale or using `shutil.which` to resolve the path:
🛡️ Proposed fix to resolve git path
```diff
+import shutil
+
 def _repo_git_sha() -> Optional[str]:
+    git_path = shutil.which("git")
+    if git_path is None:
+        return None
     try:
         proc = subprocess.run(
-            ["git", "rev-parse", "HEAD"],
+            [git_path, "rev-parse", "HEAD"],
             cwd=_repo_root(),
             capture_output=True,
             check=True,
             text=True,
+            timeout=5,  # Prevent hanging
         )
     except (FileNotFoundError, subprocess.CalledProcessError):
         return None
+    except subprocess.TimeoutExpired:
+        return None
     return proc.stdout.strip() or None
```

🧰 Tools
🪛 GitHub Actions: Release Checks
[error] 114-120: Bandit (B607:start_process_with_partial_path): Starting a process with a partial executable path. Location: tensorrt_llm/_torch/auto_deploy/transform/pipeline_cache.py:114:15
🪛 Ruff (0.15.7)
[error] 115-115: Starting a process with a partial executable path
(S607)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/transform/pipeline_cache.py` around lines 112
- 124, The Bandit B607 warning comes from invoking the "git" executable by name
in _repo_git_sha; resolve the executable first and use the absolute path or
return None to satisfy the scanner. Update _repo_git_sha to call
shutil.which("git") (import shutil) to get git_path, if git_path is falsy return
None, then call subprocess.run with [git_path, "rev-parse", "HEAD"] (keeping
cwd=_repo_root(), capture_output=True, check=True, text=True) and preserve the
existing exception handling; this explicitly documents/ensures the executable
path is not partial and satisfies the security scanner.
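Assembled as a standalone sketch (with `_repo_root()` replaced by a plain `repo_root` parameter for illustration), the proposed fix looks like:

```python
import shutil
import subprocess
from typing import Optional


def repo_git_sha(repo_root: str = ".") -> Optional[str]:
    """Return HEAD's commit SHA, or None if git is unavailable or the call fails."""
    git_path = shutil.which("git")  # absolute path satisfies Bandit B607
    if git_path is None:
        return None
    try:
        proc = subprocess.run(
            [git_path, "rev-parse", "HEAD"],
            cwd=repo_root,
            capture_output=True,
            check=True,
            text=True,
            timeout=5,  # prevent hanging on a wedged git invocation
        )
    except (FileNotFoundError, subprocess.CalledProcessError, subprocess.TimeoutExpired):
        return None
    return proc.stdout.strip() or None
```

Resolving the executable up front both silences the scanner and gives a clean `None` fallback on machines without git on `PATH`.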
Please do not review, prototype only.
Summary by CodeRabbit
New Features
`PipelineCacheConfig` settings (defaults: sharding boundary, `~/.cache/tensorrt_llm/auto_deploy_pipeline`).

Documentation
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.