feat(annotation): Torch-TensorRT annotation layer — custom_plugin (QDP/Triton/CuTile/CuTeDSL)#4147
feat(annotation): Torch-TensorRT annotation layer — custom_plugin (QDP/Triton/CuTile/CuTeDSL)#4147BowenFu wants to merge 9 commits intopytorch:mainfrom
Conversation
|
Hi @BowenFu! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
narendasan
left a comment
There was a problem hiding this comment.
@BowenFu can we split this out in to a PR stack? lets put tta.custom_plugin at the bottom. What I want to focus on is lets say a user already has implemented a custom operator in PyTorch backed by one of these kernels. We want to enable the AOT QDP launch of that kernel without a bunch of boilerplate: Basically this example https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/auto_generate_plugins.html with AOT QDP. There is already facilities for the converter generation and plugin registration from PyTorch Meta Kernel. Once we have that we have a solid base to look at region labeling / manual fusion and other advance usecases.
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def lower_custom_plugin( |
There was a problem hiding this comment.
Can we merge this stuff with already existing plugin autogeneration in torch_tensorrt.dynamo.conversion.plugin? Like we already automate converter generation key'ed on operator name
There was a problem hiding this comment.
I just cleaned up all tta.lower_as/tta.export_as related codes from this PR.
| return lower_custom_plugin_descriptor(ctx, descriptor, trt_inputs, name) | ||
|
|
||
|
|
||
| def register_custom_plugin_qdp( |
There was a problem hiding this comment.
Same here, these facilities already exist, they should be extended not duplicated
|
Also @BowenFu please follow the instructions I sent for how to get added to the CLA |
cc46c6c to
29abd38
Compare
Sure. Will include only tta.custom_plugin in this PR. |
5d036ae to
cf485d2
Compare
493cf05 to
e7b6d3b
Compare
…OT integration Adds the torch_tensorrt.annotation (tta) module for authoring TensorRT PluginV3 (QDP) backends using Triton, CuTile, and CuTeDSL kernels. Key features ------------ - tta.custom_plugin(kernel_spec, meta_impl, **attrs): factory that records a deterministic AOT fingerprint and stores kernel + attrs. meta_impl is required (ValueError if None). - tta.triton / tta.cutile / tta.cutedsl: per-backend KernelSpec helpers. - trt_plugins.custom_op(impl=tta.custom_plugin(...)): full Dynamo converter + QDP registration in one call. - Dynamic-shape support via ShapeExpr / SymInt32 in all three backends. - Multi-config tactic selection: TRT benchmarks per-config PTX at engine-build time and picks the fastest tactic automatically. - Multi-output plugins: meta_impl returning tuple -> Tuple[TensorDesc,...] Bug fixes --------- - register_dynamo_plugin: include weight count in QDP num_inputs so trt.add_constant weight tensors are counted in the plugin schema. - _build_desc_fn: pass num_dynamic = num_inputs - num_weights to _build_meta_impl_desc_fn so meta_impl only receives activation descs. Tests ----- - 55 E2E tests covering: Triton/CuTile/CuTeDSL backends, dynamic shapes, multi-output, attrs, tensor weights, bf16/fp16, 3D inputs, production scale, cross-backend same engine. - 41 unit tests.
d1a1767 to
1ac78b7
Compare
…on belongs to tta-full
- _symbolic.py: guard _ShapeDim.__int__ for dynamic dims; wrap _strides_from_td in try/except AttributeError; fix scalar numel to SymInt32(1) (not 0); fix shape_dim() to use self._shape[dim]; add negative-dim normalisation to stride(); add cdiv divisor > 0 guard - _qdp_utils.py: extend make_qdp_symbol hash from 8 to 16 hex chars (reduces birthday-collision risk); enrich analyze_launch_args error messages with position, role, and total counts - _cutile.py: scope PATH mutation with try/finally restore; cleanup CUBIN tempfile in finally block; scope PTX name replace to .entry directive via re.sub; add 500 KB scan-limit debug log; pass param_binding_indices to _launch_params_from_trt - _layer_metadata.py: add _validate_attr_key() for encoding-time safety; call it in _format_attrs; extract torch_op via raw.find() (handles spaces in path); return None for empty torch_op; add debug log for unexpected inter-token tokens - _specs.py: remove dead _custom_plugin_spec() factory; tighten input_formats/output_formats type to Optional[Sequence[int]] - _recorders.py: remove unused field import; add 3-element grid/block validation in _CuTeDSLLaunchProxy.launch - _custom_op.py: narrow impl type annotation from Any to Optional[CustomPluginSpec] with TYPE_CHECKING guard - tests: add cross-instance op_name test; add fn_specs round-trip test and empty-torch_op None test in test_layer_metadata; add NVBUG comment before expectedFailure; fix _W_COLUMN_SCALE to use _LLM_H; tighten BF16 tolerance 1e-1 → 2e-2, FP16 tolerance 1e-2 → 1e-3
- _descriptor.py: remove _build_attr_params (attrs via field params, never called) and _build_identity_desc_fn (meta_impl is required, None path unreachable) - _qdp_utils.py: remove is_cutedsl_compile_fn (heuristic never called), dtype_token (superseded by inline TRT dtype checks), collect_allowed_formats_for_io (informational only, never drove autotune; documented LIMITATION removed), make_td_from_meta (no callers anywhere in the codebase) - _aot/_cutedsl.py: remove _make_compile_wrapper (dead helper, never called) - _aot/_symbolic.py: remove cdiv (users call triton.cdiv / ct.cdiv directly) - _impl.py: delete entire file — CustomPluginTacticManager superseded by _descriptor.py; zero imports found across the whole codebase - conftest.py: remove _CUDNN_TEST_FILES, _ensure_cudnn_on_ld_path, _IS_CUDNN_SUBPROCESS and the CuDNN subprocess block in pytest_sessionfinish; the referenced test files (test_plugin_e2e.py, test_cudnn_plugin_e2e.py) do not exist — the entire CuDNN path was dead - debug_symint.py, repro_myelin_symintexprs.py: delete standalone debug scripts that were never meant to be in the test tree
Add test_recorders.py and test_errors.py (new), extend test_specs.py
and test_layer_metadata.py with targeted tests for all previously
uncovered branches.
Pure-Python files after this commit:
_errors.py 100% (TTADiagnosticError with/without leaf_op/impl_id)
_recorders.py 100% (all 3 recorder classes: Triton, CuTile, CuTeDSL)
_specs.py 100% (AnnotationMetadata helpers, KernelImplSpec
list-kernel branch + all __post_init__ guards)
_layer_metadata.py 99% (1 logically dead line: tok_idx>=len guard that
requires len==3 and len>=4 simultaneously)
TRT-dependent files (_aot/*, _descriptor.py, _qdp_utils.py, _symbolic.py,
_lowering.py) remain 0–33%: they import tensorrt at module level so the
unit-test process cannot load them; covered by e2e tests only.
narendasan
left a comment
There was a problem hiding this comment.
Marked a bunch of immediate stuff that stood out. the TL;DR is theres a ton of re-implementation here and some stuff (particularly the triton stuff) seems hacky.
My general recommendation is focus on adding aot_impl and perhaps autotune (@bowang007 did you look at autotune at all?) to the existing plugin system which should handle the rest, rather than essentially making a whole second version. If there are limitations you are running into with what is there then I believe that is where the technical discussion should be centered. Like perhaps the locking system (cc: @bowang007)
I would also recommend trying to make the systems for defining launch parameters more generically applicable, then we dont need to do as much work to say add support for pallas or nvrtc kernels.
Also I would recommend that all the sort of kernel encapsulation stuff would nicely fit in a namespace called torch_tensorrt.kernels or torch_tensorrt.dynamo.kernels It would be immediately obvious what you would use the namespace for then.
| requires_output_allocator, | ||
| ) | ||
| if impl is not None: | ||
| impl.register_dynamo_plugin( |
There was a problem hiding this comment.
We dont need a complete second code path for this, lets reuse what we have. For example creating the converter is likely the only user of capability_validator, priority, supports_dynamic_shapes, requires_output_allocator. generate_plugin_converter already creates this converter key'ed on name,
Really I would expect the code to look like:
generate_plugin(op_name) # Generates JIT QDP Plugin
generate_plugin_converter(op_name, capability_validator, priority, supports_dynamic_shapes, requires_output_allocator) # Generates the converter that inserts the QDP plugin
if impl: #this should be kernel_impl probably
impl.generate_plugin_aot_impl() | return getattr(fn, _ANNOTATION_METADATA_ATTR, None) | ||
|
|
||
|
|
||
| # ── Custom kernel specs (Triton / CuTile / CuTeDSL) ────────────────────────── |
There was a problem hiding this comment.
Do we have one for NVRTC?
There was a problem hiding this comment.
@narendasan I would expect one TVM path later after TensorRT actually supports that. NVRTC is too flexible which we need to add lots of constraints on it or bridge work to make it work with the existing QDP path.
There was a problem hiding this comment.
There are users of the plugin system who are explicitly using NVRTC which is why I ask
|
|
||
|
|
||
| @dataclass | ||
| class AnnotationMetadata: |
There was a problem hiding this comment.
Is this needed in this PR?
There was a problem hiding this comment.
Pushed the Metadata related changes to a future PR.
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def set_tta_layer_metadata( |
| @@ -0,0 +1,546 @@ | |||
| """TTA metadata stored on TensorRT ``ILayer`` objects as a plain string. | |||
There was a problem hiding this comment.
Is this needed in this PR or one of the higher level ones in the PR stack?
There was a problem hiding this comment.
layer metadata is not necessary for the functionality. Will defer it to a separate PR. It is used to map back TRT engine layers back to torch codes, which would be used by some other high level annotations (which will come later).
| default_factory=dict, hash=False, compare=False | ||
| ) | ||
|
|
||
| def register_dynamo_plugin( |
There was a problem hiding this comment.
This is mostly redundant, all we need is kernel PTX -> AOT IMPL
| _meta = self.meta_impl | ||
| _num_outputs = self.num_outputs | ||
|
|
||
| # CuTeDSL @cute.jit functions expect cute.Tensor, not torch.Tensor. |
There was a problem hiding this comment.
I think we need this only if users are using the kernel in TensorRT and not in PyTorch as custom op, I dont expect this to be very common. Even in eager execution of manually fused regions in future PRs, I would think the expectation is that there should be some custom op that we generate a fx pass to insert
| # causing a downstream shape mismatch at TRT engine build time. Users with | ||
| # shape-sensitive or rank-sensitive meta_impl functions must pass num_outputs | ||
| # explicitly to custom_plugin() to avoid this. | ||
| def _infer_num_outputs(meta_impl: Callable[..., Any]) -> int: |
There was a problem hiding this comment.
We have this already for custom ops
| return False | ||
|
|
||
|
|
||
| def _build_meta_impl_desc_fn( |
There was a problem hiding this comment.
same here, see:
|
|
||
| # Mapping from TRT dtype enum values to the string tokens accepted by | ||
| # AutoTuneCombination. Populated lazily only when TRT is available. | ||
| _TRT_DTYPE_TOKEN: Dict[Any, str] = ( |
There was a problem hiding this comment.
Just add autotune support to the existing plugin system
Commit 1fc07ff ("fix(annotation): address all code-review findings from 10-reviewer audit") introduced several changes that broke all 55 e2e tests. Bisect identified it as the first bad commit; c3e0142 was the last good. Root cause: _triton.py changed ptx.replace(kernel_name_str, unique_name) to a scoped re.sub that only renames the .entry directive. This left .param declarations with the original kernel name prefix, causing _fix_triton_ptx_for_trt to find zero matching params and raise TTAPluginError when reordering was needed. TRT caught the exception from the aot_impl callback and returned None from add_plugin, which then crashed at plugin_layer.name = name. Reverts the following files to their c3e0142 state: - _aot/_triton.py (PTX rename + bounds-check regressions) - _aot/_cutile.py (same scoped re.sub regression) - _aot/_cutedsl.py (sandbox try/except + tempdir cleanup regressions) - _descriptor.py (_is_symbolic_shape_expr heuristic change) - _qdp_utils.py (make_qdp_symbol hash 8->16 char change) - _symbolic.py (numel/stride/shape_dim guard regressions) - _lowering.py (guard regression) - _impl.py (restored — was removed in later commits but needed here) All 181 tests now pass (126 unit + 55 e2e, 3 xfailed).
…verter infrastructure Address PR review comment (pytorch#4147): register_dynamo_plugin created a parallel converter registration path duplicating the logic already in generate_plugin_converter. Changes: - custom_op(impl=None): unchanged — calls generate_plugin + generate_plugin_converter - custom_op(impl=...): replaces impl.register_dynamo_plugin() with an inline converter that calls register_custom_plugin (QDP reg with weight support) + lower_to_trt (weight injection via trt.add_constant), registered via the same dynamo_tensorrt_converter decorator used by generate_plugin_converter - _generate_plugin_converter: return tuple for multi-output plugins All 181 tests pass (126 unit + 55 e2e, 3 xfailed).
…llow-up PR Address PR review comments pytorch#4147: - Comment 4: Remove AnnotationMetadata / attach_annotation_metadata / get_annotation_metadata from _specs.py and __init__.py — unused in this PR (designed for @tta.export_as which is deferred). - Comment 5: Remove _layer_metadata.py and its set_tta_layer_metadata call in lower_custom_plugin_descriptor — diagnostic-only (TRT engine inspector), non-fatal by design, out of scope for this PR. Both modules are preserved in the backup branch and will be reintroduced in a higher-level diagnostics / export_as PR.
| tta.normalize_impl_to_spec(123) | ||
|
|
||
|
|
||
| if __name__ == "__main__": |
There was a problem hiding this comment.
This test spec seems a little opaque to me. I have no idea what it tests without diving into the code and understand that the spec includes difference specification for each type of kernel. We could make it more straightforward.
| pass | ||
|
|
||
| normalized = tta.normalize_impl_to_spec(tta.cutedsl(kernel)) | ||
| self.assertIsInstance(normalized, tta.CustomPluginSpec) |
There was a problem hiding this comment.
If I understand correctly, this won't include the runtime test right?
How can we make sure that the kernel produces correct output with different config?
There was a problem hiding this comment.
unit tests does not test accuracy. Please check e2e tests in integration folder.
Motivation
Torch-TRT compiles PyTorch models to TensorRT engines, but today there is no first-class path for users who want to replace a subgraph with their own Triton, CuTile, or CuTeDSL kernel inside the compiled engine. The typical workaround — writing a C++ TRT plugin and registering it manually — requires leaving the Python ecosystem, managing separate build systems, and wiring up the plugin registry by hand. This is a significant barrier for researchers and ML engineers who already have high-performance Python kernels.
TensorRT 10.x introduced Quick Deployable Plugins (QDP), which support AOT-compiled Python kernels (
@trtp.aot_impl) that are embedded directly into the TRT engine with no Python required at runtime. This PR adds the descriptor and registration layer that lets users express a custom QDP plugin as a plain Python object and pass it to Torch-TRT — with no changes to any core compiler files.What's in this PR
Public API (
import torch_tensorrt.annotation as tta):tta.triton(launch_fn, configs)TritonSpectta.cutile(launch_fn, arch, configs)CuTileSpectta.cutedsl(launch_fn, configs)CuTeDSLSpectta.custom_plugin(impl)CustomPluginSpecQDP registration (
_custom_plugin/):_descriptor.py—CustomPluginSpecdataclass +custom_plugin()factory; computes a deterministic op name from the kernel function identity + config hash;register_custom_plugin()registers@trtp.register/@trtp.autotune/@trtp.aot_implwith TRT's process-global QDP registry using double-checked locking for xdist safety._lowering.py— lowers aCustomPluginSpecto a TRT plugin layer viactx.net.add_plugin(trtp.op.<ns>.<name>(*inputs), aot=True); injects weight tensors asadd_constantlayers._qdp_utils.py— deterministic op-name derivation, tactic table building, meta-tensor helpers for symbolic shape inference._symbolic.py—SymbolicTensorabstraction for QDP shape/dtype descriptor registration.AOT backends (
_custom_plugin/_aot/):_triton.py— Triton → PTX viatriton.compile; per-config tactic entries._cutile.py— CuTile → cubin viatileiras; sm_100+ only._cutedsl.py— CuTeDSL → PTX/cubin vianvidia-cutlass-dsl.Supporting modules:
_specs.py—TritonSpec,CuTileSpec,CuTeDSLSpec,KernelImplSpecfrozen dataclasses;triton()/cutile()/cutedsl()factories;normalize_impl_to_spec()._layer_metadata.py—set_tta_layer_metadata()helper for stamping TRT layer metadata; encode/decode round-trip for custom plugin attribution._recorders.py— launch-parameter recording for Triton/CuTile/CuTeDSL AOT backends._validation.py— spec and descriptor validation utilities._errors.py—TTADiagnosticErrorstructured error type.Tests (
tests/py/annotation/unit/, CPU-only, 46 tests):test_specs.py— kernel spec construction, validation, cache-key stability.test_specs_custom_plugin.py—CustomPluginSpecandcustom_plugin()factory.test_layer_metadata.py— metadata encode/decode round-trip.Design notes
CustomPluginSpecis a plain frozen dataclass. No hooks into_compiler.py,_TRTInterpreter.py, or any other existing file. The integration point (passing a descriptor to a converter) is left for a follow-up PR.op_nameis derived from a hash of the kernel function identity, config set, and weight count. The same descriptor created in two different processes produces the same name, making engine caching safe.configsdicts produce multiple QDP tactics; TRT's autotuner benchmarks all of them at engine-build time.trt_plugins.custom_opintegration — atorch_tensorrt.dynamo.conversion.pluginsAPI that wires aCustomPluginSpecdirectly to a registeredtorch.librarycustom op, so the TRT lowering path is set up with no manual converter code:TRT's autotuner benchmarks all tactics across both backends and selects the fastest for the target GPU.
Future work
This PR establishes the descriptor and registration layer. The follow-up work:
CustomPluginSpecinto the_TRTInterpreterconverter dispatch so that annotated subgraphs are lowered to the registered QDP op duringtorch_tensorrt.compile.tta.lower_as(impl=..., name=...)context manager that tags subgraph regions duringtorch.exportfor targeted lowering to custom plugins. The intended end-to-end usage looks like:Test plan