
Multi-Device TensorRT Runtime with Native NCCL Collectives #4157

Draft
apbose wants to merge 1 commit into main from abose/trt_MD_cpp_runtime

Conversation

@apbose (Collaborator) commented Apr 1, 2026

  • C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8
  • Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection
  • Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback
  • Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops
  • Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain
  • Build: conditional NCCL compilation via torch_nccl toolchain
  • Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py
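The "fuse c10d_functional collectives + wait_tensor into single ops" bullet can be sketched in miniature. This is a toy pass over a flat list of op names, purely illustrative; the real lowering operates on torch.fx graphs and the fused op names here are hypothetical.

```python
# Toy sketch of the collective-fusion lowering pass described above.
# The real pass rewrites torch.fx graph nodes; here a graph is just a
# list of op-name strings, and the fused names are made up.

COLLECTIVES = {"all_reduce", "all_gather_into_tensor", "reduce_scatter_tensor"}

def fuse_collectives(nodes):
    """Collapse each (collective, wait_tensor) pair into one fused node."""
    fused = []
    i = 0
    while i < len(nodes):
        op = nodes[i]
        if (op in COLLECTIVES
                and i + 1 < len(nodes)
                and nodes[i + 1] == "wait_tensor"):
            fused.append(f"fused_{op}")
            i += 2  # consume both the collective and its wait
        else:
            fused.append(op)
            i += 1
    return fused

graph = ["embedding", "all_reduce", "wait_tensor", "linear",
         "reduce_scatter_tensor", "wait_tensor"]
print(fuse_collectives(graph))
# → ['embedding', 'fused_all_reduce', 'linear', 'fused_reduce_scatter_tensor']
```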
@meta-cla meta-cla bot added the cla signed label Apr 1, 2026
@apbose apbose marked this pull request as draft April 1, 2026 23:52
@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: converters Issues re: Specific op converters component: build system Issues re: Build system component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Apr 1, 2026
@narendasan (Collaborator) left a comment

The high-order bits: there is too much NCCL management happening all over the place, and state is not managed tightly enough. Fundamentally, setting up NCCL is not our job. The user tells us the information we need (the communicator and which rank the process is on) and we trust them to do the rest.

The runtime modules do not need to care whether NCCL is set up, as long as the information needed to deserialize and set up the engine is available.

Also, it seems the C++ runtime only uses NCCL from c10d, but there's other code which falls back to NCCL from the NCCL Python bindings. What do we actually want to support? Just c10d, or both?
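The "trust the user" contract the review asks for could look something like the following sketch: the caller sets up distributed state, and the runtime only consumes rank/world_size. The names (`DistributedInfo`, `from_user_or_env`) are illustrative, not the actual Torch-TensorRT API; `RANK`/`WORLD_SIZE` are the standard torchrun environment variables.

```python
# Hypothetical sketch: the runtime accepts distributed info from the
# caller, falls back to the standard torchrun env vars, and otherwise
# fails loudly -- it never initializes NCCL itself.
import os
from dataclasses import dataclass

@dataclass(frozen=True)
class DistributedInfo:
    rank: int
    world_size: int

def from_user_or_env(rank=None, world_size=None):
    """Prefer explicit values from the caller; else read RANK/WORLD_SIZE."""
    if rank is not None and world_size is not None:
        return DistributedInfo(rank, world_size)
    try:
        return DistributedInfo(int(os.environ["RANK"]),
                               int(os.environ["WORLD_SIZE"]))
    except KeyError as e:
        raise RuntimeError(f"distributed context not provided: missing {e}")
```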

REQUIRES_OUTPUT_ALLOCATOR_IDX,
RESOURCE_ALLOCATION_STRATEGY_IDX,
RANK_IDX,
WORLD_SIZE_IDX,

Make sure to bump the ABI version


Fields that are only sometimes applicable should be prefixed with OPTIONAL_ and need guards
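The OPTIONAL_ convention being suggested can be sketched as follows: an optional serialized field is written behind a presence flag, so readers only touch it when it exists. The field names and flat-list encoding here are hypothetical stand-ins for the engine's serialization format.

```python
# Illustrative sketch of guarded optional fields in a serialized record.
# A presence flag precedes each OPTIONAL_ field; readers check the guard
# before consuming the value. Layout and names are made up.

def serialize(engine_blob, optional_rank=None):
    record = [engine_blob]
    has_rank = optional_rank is not None
    record.append(int(has_rank))        # guard flag for OPTIONAL_RANK
    if has_rank:
        record.append(optional_rank)    # only written when present
    return record

def deserialize(record):
    engine_blob = record[0]
    rank = record[2] if record[1] else None  # guarded read
    return engine_blob, rank
```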

? ResourceAllocationStrategy::kDynamic
: ResourceAllocationStrategy::kStatic)) {}
: ResourceAllocationStrategy::kStatic)) {
// Load distributed info if available (backward compatible with older ABI versions)

We don't need backwards compat unless some semantic definition changed; just bump the version.
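The "just bump the version" policy amounts to an exact-match check at load time, rather than per-field fallbacks for older layouts. A minimal sketch, assuming a single integer ABI version in the serialized form (the constant and message are illustrative):

```python
# Sketch: the runtime rejects any serialized engine whose ABI version it
# was not built for, instead of guessing at older layouts.

TARGET_ABI_VERSION = 8  # bumped whenever the serialized layout changes

def check_abi(serialized_version):
    if serialized_version != TARGET_ABI_VERSION:
        raise RuntimeError(
            f"engine serialized with ABI v{serialized_version}, "
            f"runtime expects v{TARGET_ABI_VERSION}; recompile the engine")
```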

return false;
}

if (this->nccl_comm == nullptr) {

If it can be avoided, do not let anything be null. Use a sentinel value or a smart pointer.
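The sentinel alternative to a nullable communicator, sketched in Python for brevity (the C++ analogue would be a non-null default object or a smart pointer): the absent state is a real object with defined failing behavior, so call sites never null-check. All class names here are hypothetical.

```python
# Sketch of the sentinel pattern: instead of nccl_comm == nullptr checks
# scattered around, the "no communicator" state is an object whose use is
# an immediate, descriptive error.

class NoComm:
    """Sentinel communicator: any collective call fails loudly."""
    def all_reduce(self, tensor):
        raise RuntimeError("collective called before a communicator was attached")

class RealComm:
    def all_reduce(self, tensor):
        return tensor  # stand-in for the actual NCCL call

class Engine:
    def __init__(self):
        self.comm = NoComm()  # never None; always safe to call through
```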

}

// Set NCCL communicator on TensorRT execution context
try {

We need a real state machine or some true semantics here. Nothing else in the runtime just try/catches.

LOG_INFO(" Current world_size: " << this->world_size);
LOG_INFO(" Current device_id: " << this->device_info.id);

try {

Same here: no nested try/catches. You need to provide guarantees about the state of the system.
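One way to give the init path real semantics, as requested: an explicit state enum with a single legal transition order, so an illegal call is a programming error rather than a swallowed exception. The states and method names below are illustrative, not the actual runtime API.

```python
# Sketch: explicit lifecycle states for the communicator setup path,
# replacing nested try/except. Every transition checks the current state
# and fails with a clear message if called out of order.
from enum import Enum, auto

class CommState(Enum):
    UNINITIALIZED = auto()
    COMM_ATTACHED = auto()
    READY = auto()

class CommLifecycle:
    def __init__(self):
        self.state = CommState.UNINITIALIZED

    def attach_comm(self):
        if self.state is not CommState.UNINITIALIZED:
            raise RuntimeError(f"attach_comm called from {self.state}")
        self.state = CommState.COMM_ATTACHED

    def bind_to_context(self):
        if self.state is not CommState.COMM_ATTACHED:
            raise RuntimeError(f"bind_to_context called from {self.state}")
        self.state = CommState.READY
```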

logger.error(f"Failed to set NCCL communicator: {e}")
raise

def get_nccl_communicator(self) -> Optional[Any]:

We don't need getters and setters; if you want people to use this field, just make it public.
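In Python there is no cost to exposing the attribute directly; a `get_`/`set_` pair adds indirection without encapsulation. A minimal sketch (class and attribute names illustrative):

```python
# Sketch: a public attribute replaces get_nccl_communicator()/set_...;
# callers read and write module.nccl_comm directly.

class Module:
    def __init__(self):
        self.nccl_comm = None  # public; no accessor methods needed
```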

"""Get the NCCL communicator if set."""
return self._nccl_comm

def setup_nccl(self, use_pytorch_comm: bool = True) -> None:

This should be on the user


uid = nccl.UniqueId.from_bytes(bytes(uid_tensor.cpu().numpy()))

comm = nccl.Communicator.init(world_size, rank, uid)

Is this a singleton owned by NCCL, or can you have multiple? Can a user just give this to us?

}
}

void TRTEngine::init_nccl_comm(const std::string& group_name) {

Why do we need a method that transparently calls another method?

LOG_INFO(" Got NCCL backend from ProcessGroup");

// Cast the backend to ProcessGroupNCCL
auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get());

Try to use smart pointers if possible

weight_name_map: Optional[dict[Any, Any]] = None,
requires_output_allocator: bool = False,
symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None,
rank: int = -1,

I think all we need is a flag that this is an md_engine; we can fetch the info from torch.distributed or the env vars internally.
