Skip to content

MD-TRT Support, Compile/Export, C++ and Python #4183

Open
narendasan wants to merge 17 commits intomainfrom
push-vqqzkszwrvyx
Open

MD-TRT Support, Compile/Export, C++ and Python #4183
narendasan wants to merge 17 commits intomainfrom
push-vqqzkszwrvyx

Conversation

@narendasan
Copy link
Copy Markdown
Collaborator

Description

Opening this to test the CI

Fixes # (issue)

Type of change

Please delete options that are not relevant and/or add your own.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

apbose and others added 11 commits April 12, 2026 11:41
- C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8
- Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection
- Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback
- Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops
- Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain
- Build: conditional NCCL compilation via torch_nccl toolchain
- Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py
…hapes

Five interconnected fixes:

1. fold_get_attr_item_calls: fold scalar param .item() calls into Python
   scalars before AOT tracing. Inside FakeTensorMode, even real-tensor
   .item() calls raise DataDependentOutputException.

2. backends.py: three changes:
   - call fold_get_attr_item_calls before entering FakeTensorMode
   - detect vmap/higher-order ops and route them through aot_autograd
     instead of aot_export_joint_simple (which doesn't handle HOPs)
   - on TRT build failure, strip TRT-only kwargs (use_fp32_acc) from
     the fallback graph before returning it to PyTorch

3. _decompositions.py: prevent SDPA from leaking back into the decomp
   table via Core ATen Interchange ops even after being removed from
   TORCH_TRT_DECOMPOSITIONS.

4. partitioning/common.py: lower the default max dynamic shape from
   min*2^16 to min*2^12 — 65536 is too large for TRT to find kernel
   implementations for attention ops.

5. _TorchTensorRTModule.py: move CPU scalar inputs to CUDA before
   execution — aot_autograd lifts scalar attributes (e.g. head_dim^-0.5)
   as explicit graph inputs; TRT requires all inputs on CUDA.

Also fixes remove_sym_nodes to match tensor sources by equality rather
than local_name so that GetItemSource bases (from torch.compile
dynamic=True) are matched correctly, and updates register_sdpa.py to
handle aten.scaled_dot_product_attention.default (the form produced after
aot_autograd) in addition to the flash/efficient variants.
@meta-cla meta-cla bot added the cla signed label Apr 12, 2026
@github-actions github-actions bot added documentation Improvements or additions to documentation component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: build system Issues re: Build system component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: torch_compile labels Apr 12, 2026
@github-actions github-actions bot requested a review from zewenli98 April 12, 2026 19:09
github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

github-actions[bot]

This comment was marked as outdated.

@narendasan narendasan force-pushed the push-vqqzkszwrvyx branch 5 times, most recently from 473cff9 to 9022e03 Compare April 13, 2026 01:14

std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
// All inputs are expected to be on CUDA. Warn and move any that are not.
for (auto& inp : inputs) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to remove this but didnt have time to check if the device operations in python suppress this correctly

// the constructor-time bind was deferred (e.g. no collective had been issued
// at construction time, or for serialized programs loaded inline where there
// is no Python _TorchTensorRTModule.forward wrapper).
if (compiled_engine->is_md && !compiled_engine->nccl_initialized) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not entirely sure this is necessary

Comment thread core/runtime/register_jit_hooks.cpp Outdated
// process group from the c10d registry. PyTorch assigns sequential
// numeric names ("0", "1", ...) to process groups; probe until we
// find one with an NCCL backend.
if (this->group_name.empty() && this->is_md) {
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should only do this if there is one available group. If there are multiple NCCL groups available we should tell the user to manually select


def forward(self, x):
out = self.linear(x)
out = torch.ops._c10d_functional.all_reduce(out, "sum", self.group_name)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets dig into this more after the PR lands

logger = logging.getLogger("torchtrtrun")


def _get_nccl_lib_dir() -> Optional[str]:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move into its own file


self._nccl_comm: Optional[Any] = None
self._has_nccl_ops: bool = False

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be set before the self.setup_engine()

inspector = self.engine.create_engine_inspector()
engine_json = inspector.get_engine_information(trt.LayerInformationFormat.JSON)
self._has_nccl_ops = "NCCL" in engine_json or "AllReduce" in engine_json

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something like this works
engine_json_lower = engine_json.lower()
self._has_nccl_ops = "dist_collective" in engine_json_lower or "nccl" in engine_json_lower or "allreduce" in engine_json_lower

Copy link
Copy Markdown

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_md.py	2026-04-14 20:34:53.887235+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_md.py	2026-04-14 20:35:10.238612+00:00
@@ -97,13 +97,13 @@
        x = self.out_proj2(self.relu(self.in_proj2(x)))
        return x


def get_model(device_mesh):
-    assert world_size % 2 == 0, (
-        f"TP examples require an even number of GPUs, got {world_size}"
-    )
+    assert (
+        world_size % 2 == 0
+    ), f"TP examples require an even number of GPUs, got {world_size}"
    model = ToyModel().to(DEVICE)
    parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan={
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py	2026-04-14 20:34:53.901309+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/custom_ops_converters.py	2026-04-14 20:35:11.188933+00:00
@@ -22,11 +22,13 @@

if ENABLED_FEATURES.native_trt_collectives:
    # Use native TensorRT DistCollective API (no TensorRT-LLM dependency)
    _LOGGER.info("Using native TensorRT DistCollective API for distributed operations")

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_gather_op, requires_multidevice=True)
+    @dynamo_tensorrt_converter(
+        tensorrt_fused_nccl_all_gather_op, requires_multidevice=True
+    )
    def fused_nccl_gather(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
@@ -39,11 +41,13 @@
            SourceIR.ATEN,
            name,
            [args[0]],
        )

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_reduce_scatter_op, requires_multidevice=True)
+    @dynamo_tensorrt_converter(
+        tensorrt_fused_nccl_reduce_scatter_op, requires_multidevice=True
+    )
    def fused_nccl_reduce_scatter(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
@@ -56,11 +60,13 @@
            SourceIR.ATEN,
            name,
            [args[0]],
        )

-    @dynamo_tensorrt_converter(tensorrt_fused_nccl_all_reduce_op, requires_multidevice=True)
+    @dynamo_tensorrt_converter(
+        tensorrt_fused_nccl_all_reduce_op, requires_multidevice=True
+    )
    def fused_nccl_all_reduce(
        ctx: ConversionContext,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Argument],
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2026-04-14 20:34:53.900894+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2026-04-14 20:35:11.722086+00:00
@@ -669,11 +669,11 @@
            cuda_engine,
            self._input_names,
            self._output_names,
            self.weight_name_map,
            self.ctx.requires_output_allocator,
-            self.ctx.requires_multidevice
+            self.ctx.requires_multidevice,
        )

    def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
        self._cur_node_name = get_node_name(n)
        self._cur_node = n
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2026-04-14 20:34:53.907339+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2026-04-14 20:35:13.284044+00:00
@@ -368,12 +368,18 @@
            # For engines with native NCCL collective layers, all ranks must
            # have a live IExecutionContext before any rank executes a
            # collective. Barrier here so a fast-compiling rank does not race
            # ahead and issue an NCCL op while another rank is still inside
            # deserialize_cuda_engine / create_execution_context.
-            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
-                logger.debug("Barrier after execution context creation (distributed NCCL engine)")
+            if (
+                dist.is_available()
+                and dist.is_initialized()
+                and dist.get_world_size() > 1
+            ):
+                logger.debug(
+                    "Barrier after execution context creation (distributed NCCL engine)"
+                )
                dist.barrier()

        assert self.context is not None, "Failed to create execution context"
        assert self.engine.num_io_tensors == (
            len(self.input_names) + len(self.output_names)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_native_nccl.py	2026-04-14 20:34:53.931747+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/distributed/test_native_nccl.py	2026-04-14 20:35:15.878751+00:00
@@ -1525,13 +1525,11 @@

        expected = torch.full((1, 4), float(expected_sum), device=device)
        _check_close(out, expected, f"context_switch sg{i+1} rank={rank}")


-def _multirank_pg_migration(
-    rank: int, world_size: int, device: torch.device
-) -> None:
+def _multirank_pg_migration(rank: int, world_size: int, device: torch.device) -> None:
    """Compile with the default world group, run inference, then migrate to a new
    subgroup via distributed_group(new_group, model) and verify that inference
    still produces correct results — i.e. the NCCL communicator is re-bound.

    Tests both the C++ runtime (set_group_name resets nccl_initialized) and the
@@ -1562,13 +1560,11 @@
        def __init__(self, pg_name: str) -> None:
            super().__init__()
            self.pg_name = pg_name

        def forward(self, x: torch.Tensor) -> torch.Tensor:
-            out = torch.ops._c10d_functional.all_reduce.default(
-                x, "sum", self.pg_name
-            )
+            out = torch.ops._c10d_functional.all_reduce.default(x, "sum", self.pg_name)
            return torch.ops._c10d_functional.wait_tensor.default(out)

    inp = torch.full((1, 4), float(rank + 1), device=device)
    expected_sum = world_size * (world_size + 1) // 2
    expected = torch.full((1, 4), float(expected_sum), device=device)
@@ -1600,13 +1596,11 @@
        # lazy setup_nccl_comm() call.
        with distributed_group(subgroup, trt_model) as migrated_model:
            with torch.no_grad():
                out_sub = migrated_model(inp)

-        _check_close(
-            out_sub, expected, f"[{label}] migrated to subgroup rank={rank}"
-        )
+        _check_close(out_sub, expected, f"[{label}] migrated to subgroup rank={rank}")

        # ---- Step 3: set_distributed_group (persistent, outside context) ----
        subgroup2 = dist.new_group(ranks=list(range(world_size)))
        torch_tensorrt.distributed.set_distributed_group(trt_model, subgroup2)
        # _state.pg is NOT set here — Python runtime falls back to world group
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py	2026-04-14 20:34:53.940401+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py	2026-04-14 20:35:16.561193+00:00
@@ -194,11 +194,13 @@

            tokenizer = AutoTokenizer.from_pretrained(args.model)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

-            input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
+            input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(
+                DEVICE
+            )
            max_len = input_ids.shape[1] + args.num_tokens

            logger.info("Running uncompiled PyTorch baseline ...")
            torch_tokens = generate(
                model, input_ids.clone(), max_len, tokenizer.eos_token_id
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py	2026-04-14 20:34:53.940401+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py	2026-04-14 20:35:16.626752+00:00
@@ -135,16 +135,16 @@
            )
        cfg = model.config
        parallelize_module(model, device_mesh, build_tp_plan(cfg))

    cfg = model.config
-    assert cfg.num_key_value_heads % world_size == 0, (
-        f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
-    )
-    assert cfg.num_attention_heads % world_size == 0, (
-        f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"
-    )
+    assert (
+        cfg.num_key_value_heads % world_size == 0
+    ), f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
+    assert (
+        cfg.num_attention_heads % world_size == 0
+    ), f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"

    # After column-sharding Q/K/V, each rank holds num_heads // world_size
    # heads. Patch these so HuggingFace attention reshapes correctly.
    for layer in model.model.layers:
        layer.self_attn.num_heads = cfg.num_attention_heads // world_size
@@ -209,11 +209,11 @@
    parser.add_argument(
        "--sharded_checkpoint",
        type=str,
        default="",
        help="Path to DCP sharded checkpoint (e.g. /mnt/cluster-shared/mixtral_sharded). "
-             "If set, skips HF weight download and loads only this rank's shard.",
+        "If set, skips HF weight download and loads only this rank's shard.",
    )
    args = parser.parse_args()

    device_mesh = init_device_mesh("cuda", (world_size,))

--- /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2026-04-14 20:34:53.941720+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/utils.py	2026-04-14 20:35:17.028688+00:00
@@ -163,11 +163,13 @@
            # block-size padding produces s1*(8+s1-s1%8)>1 guards that the
            # symbolic solver can't verify without concrete values).  Without
            # bounds, dynamo traces symbolically and TRT infers the profile
            # from the first concrete shape it sees.
            torch._dynamo.mark_dynamic(input_seq, 1)
-        position_ids = torch.arange(input_seq.shape[1], device=input_seq.device).unsqueeze(0)
+        position_ids = torch.arange(
+            input_seq.shape[1], device=input_seq.device
+        ).unsqueeze(0)
        if dynamic_seqlen_range is not None:
            torch._dynamo.mark_dynamic(position_ids, 1)
        outputs = model(input_seq, position_ids=position_ids)
        logits = outputs.logits
        next_token_logits = logits[:, -1, :]

Copy link
Copy Markdown

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py	2026-04-14 21:04:44.996426+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_mixtral_llm.py	2026-04-14 21:05:11.648796+00:00
@@ -135,16 +135,16 @@
            )
        cfg = model.config
        parallelize_module(model, device_mesh, build_tp_plan(cfg))

    cfg = model.config
-    assert cfg.num_key_value_heads % world_size == 0, (
-        f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
-    )
-    assert cfg.num_attention_heads % world_size == 0, (
-        f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"
-    )
+    assert (
+        cfg.num_key_value_heads % world_size == 0
+    ), f"num_key_value_heads ({cfg.num_key_value_heads}) not divisible by world_size ({world_size})"
+    assert (
+        cfg.num_attention_heads % world_size == 0
+    ), f"num_attention_heads ({cfg.num_attention_heads}) not divisible by world_size ({world_size})"

    # After column-sharding Q/K/V, each rank holds num_heads // world_size
    # heads. Patch these so HuggingFace attention reshapes correctly.
    for layer in model.model.layers:
        layer.self_attn.num_heads = cfg.num_attention_heads // world_size
@@ -209,11 +209,11 @@
    parser.add_argument(
        "--sharded_checkpoint",
        type=str,
        default="",
        help="Path to DCP sharded checkpoint (e.g. /mnt/cluster-shared/mixtral_sharded). "
-             "If set, skips HF weight download and loads only this rank's shard.",
+        "If set, skips HF weight download and loads only this rank's shard.",
    )
    args = parser.parse_args()

    device_mesh = init_device_mesh("cuda", (world_size,))

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the present changes, this example wont run with just torchtrtrun since that uses torchrun. Is the expectation that the user do LD_LIBRARY_PATH and run the test?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or accordingly we need to change the initialization code in this

Copy link
Copy Markdown

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp b/tmp/changes.txt
index f5171b4..d4bfb2b 100644
--- a/home/runner/work/TensorRT/TensorRT/core/runtime/TRTEngine.cpp
+++ b/tmp/changes.txt
@@ -634,7 +634,8 @@ void TRTEngine::release_nccl_comm() {
  } else {
    this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
  }
-  TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm");
+  TORCHTRT_CHECK(
+      (exec_ctx.get() != nullptr), "Unable to recreate TensorRT execution context after releasing NCCL comm");
  this->nccl_initialized = false;
  LOG_INFO("NCCL communicator released from engine '" << this->name << "'");
}
ERROR: Some files do not conform to style guidelines

Copy link
Copy Markdown

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_mn.py	2026-04-16 03:29:13.283943+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example_mn.py	2026-04-16 03:29:32.666239+00:00
@@ -196,11 +196,13 @@
            logger.info("Warming up (triggering TRT engine build)...")
            _ = trt_model(inp)
            dist.barrier()
            logger.info("All ranks compiled. Running inference...")

-            with torch_tensorrt.distributed.distributed_context(dist.group.WORLD, trt_model) as dist_model:
+            with torch_tensorrt.distributed.distributed_context(
+                dist.group.WORLD, trt_model
+            ) as dist_model:
                output = dist_model(inp)

            assert (python_result - output).std() < 0.01, "Result mismatch"
            logger.info("JIT compile successful!")

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/distributed/_distributed.py	2026-04-16 03:29:13.300596+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/distributed/_distributed.py	2026-04-16 03:29:33.275892+00:00
@@ -19,10 +19,11 @@
import torch.nn as nn

M = TypeVar("M", bound=nn.Module)

_state = threading.local()
+

def register_md_engine(engine: object) -> None:
    """Register a C++ TRTEngine with is_md=True for teardown tracking.

    Engines are stored on the thread-local ``_state`` so they are scoped
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py	2026-04-16 03:29:13.349066+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_multinode.py	2026-04-16 03:29:39.352181+00:00
@@ -192,13 +192,11 @@

        tokenizer = AutoTokenizer.from_pretrained(args.model)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

-        input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(
-            DEVICE
-        )
+        input_ids = tokenizer(args.prompt, return_tensors="pt")["input_ids"].to(DEVICE)
        max_len = input_ids.shape[1] + args.num_tokens

        logger.info("Running uncompiled PyTorch baseline ...")
        torch_tokens = generate(
            model, input_ids.clone(), max_len, tokenizer.eos_token_id
@@ -212,11 +210,13 @@
        trt_model = compile_torchtrt(model, args)

        # Use distributed_context to manage the NCCL lifecycle.  On __exit__
        # it calls release_nccl_comm() on all engines in the module, making
        # dist.destroy_process_group() safe without manual cleanup ordering.
-        with torch_tensorrt.distributed.distributed_context(dist.group.WORLD, trt_model) as trt_model:
+        with torch_tensorrt.distributed.distributed_context(
+            dist.group.WORLD, trt_model
+        ) as trt_model:
            # Trigger TRT engine build explicitly and barrier so all ranks
            # finish compilation before the generation loop starts.
            logger.info("Warming up TRT model (triggering engine build)...")
            _warmup_ids = input_ids.clone()
            torch._dynamo.mark_dynamic(_warmup_ids, 1)
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_qwen_multinode.py	2026-04-16 03:29:13.349066+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_qwen_multinode.py	2026-04-16 03:29:39.407145+00:00
@@ -210,11 +210,13 @@
        trt_model = compile_torchtrt(model, args)

        # Use distributed_context to manage the NCCL lifecycle.  On __exit__
        # it calls release_nccl_comm() on all tracked MD engines, making
        # dist.destroy_process_group() safe without manual cleanup ordering.
-        with torch_tensorrt.distributed.distributed_context(dist.group.WORLD, trt_model) as trt_model:
+        with torch_tensorrt.distributed.distributed_context(
+            dist.group.WORLD, trt_model
+        ) as trt_model:
            # Trigger TRT engine building explicitly and wait for all ranks to
            # finish before starting the generation loop.  Without this barrier,
            # a slow TRT build on one rank causes the other rank to timeout at
            # the next NCCL collective (NCCL default watchdog = 10 min).
            logger.info("Warming up TRT model (triggering engine build)...")
--- /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_export.py	2026-04-16 03:29:13.349066+00:00
+++ /home/runner/work/TensorRT/TensorRT/tools/llm/tensor_parallel_llama_export.py	2026-04-16 03:29:39.436482+00:00
@@ -390,11 +390,11 @@
            trt_model, _ = load_and_run(input_ids, tokenizer, args)

    # Delete the TRT engine before destroying the process group — the engine
    # holds a reference to the NCCL communicator and will segfault if NCCL is
    # torn down first.
-    #del trt_model
-    #torch.cuda.empty_cache()
+    # del trt_model
+    # torch.cuda.empty_cache()
    dist.destroy_process_group()
    logger.info("Done.")
    # Bypass Python GC — TRT/CUDA destructors can segfault during interpreter shutdown.
    os._exit(0)

if id(engine) not in seen:
seen.add(id(engine))
if getattr(engine, "is_md", False):
engine.set_group_name(group_name)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to be designed in a way that it works for only C++ runtime. Why dont we have it for python runtime? Something like maintaining a group_name and nccl_comm instead of getting the active group and binding it once

if id(engine) not in seen:
seen.add(id(engine))
if getattr(engine, "is_md", False):
engine.set_group_name(group_name)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this added?

"Use torch_tensorrt.distributed.distributed_context(group) to select a non-default group."
)

backend = pg._get_backend(torch.device("cuda"))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we have a equivalent release_nccl_comm for python runtime as well?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: build system Issues re: Build system component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: runtime component: tests Issues re: Tests component: torch_compile documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants