Multi-Device TensorRT Runtime with Native NCCL Collectives #4157
Conversation
- C++ runtime: NCCL communicator init via c10d, rank/world_size serialization, DynamicOutputAllocator, ABI version bump to 8
- Python runtime: distributed support in PythonTorchTensorRTModule and TorchTensorRTModule, NCCL library auto-detection
- Conversion: native TRT DistCollective API (AllGather, ReduceScatter, AllReduce) with TRT-LLM plugin fallback
- Graph lowering: fuse c10d_functional collectives + wait_tensor into single ops
- Feature detection: native_trt_collectives flag, platform validation, graceful fallback chain
- Build: conditional NCCL compilation via torch_nccl toolchain
- Examples: tensor_parallel_simple_example.py, tensor_parallel_llama_llm.py
narendasan left a comment:
The high-order bits: there is too much NCCL management happening all over the place, and state is not managed tightly enough. Fundamentally, setting up NCCL is not our job. The user gives us the information we need (the communicator and which rank the process is on), and we trust them to do the rest.
The runtime modules do not need to care about whether NCCL is set up, as long as the information needed to deserialize and set up the engine is available.
Also, it seems like the C++ runtime only uses NCCL from c10d, but then there's other code which falls back to NCCL from the nccl Python bindings. What do we actually want to support? Just c10d, or both?
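The contract the reviewer describes could be sketched roughly as follows: the user hands the runtime a communicator and rank, and the runtime stores and validates them without ever touching NCCL init itself. All names here (`MultiDeviceEngineInfo`, `deserialize_engine`) are illustrative, not this PR's actual API.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class MultiDeviceEngineInfo:
    """Everything the runtime needs to deserialize a multi-device engine."""
    rank: int
    world_size: int
    nccl_comm: Optional[Any] = None  # opaque handle created and owned by the user


def deserialize_engine(serialized: bytes, info: MultiDeviceEngineInfo) -> dict:
    # The runtime only validates and records the info; it never initializes NCCL.
    if not (0 <= info.rank < info.world_size):
        raise ValueError(f"rank {info.rank} out of range for world_size {info.world_size}")
    return {"engine": serialized, "rank": info.rank, "world_size": info.world_size}
```

Under this shape, the "fallback chain" question disappears: whichever stack produced the communicator (c10d or raw bindings), the runtime treats it as an opaque handle.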
    REQUIRES_OUTPUT_ALLOCATOR_IDX,
    RESOURCE_ALLOCATION_STRATEGY_IDX,
    RANK_IDX,
    WORLD_SIZE_IDX,
Make sure to bump the ABI version
Fields that are only sometimes applicable should be prefixed with OPTIONAL_ and need guards
    ? ResourceAllocationStrategy::kDynamic
    : ResourceAllocationStrategy::kStatic)) {}
    : ResourceAllocationStrategy::kStatic)) {
    // Load distributed info if available (backward compatible with older ABI versions)
We don't need backwards compat unless some semantic definition changed; just bump the version.
    return false;
    }

    if (this->nccl_comm == nullptr) {
If it can be avoided, do not let anything be null. Use a sentinel value or a smart pointer.
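The sentinel idea can be illustrated in Python (the C++ analogue would be a static sentinel object or a `std::shared_ptr` that is never null). `_NoComm` and `EngineState` are hypothetical stand-ins, not code from this PR.

```python
# Instead of letting the communicator field be None/null and branching on it
# at every call site, normalize to a dedicated sentinel at the boundary.
class _NoComm:
    """Sentinel standing in for an absent NCCL communicator."""
    def is_valid(self) -> bool:
        return False


NO_COMM = _NoComm()


class EngineState:
    def __init__(self, comm=None):
        # Internally, self.comm is never None: callers can always invoke
        # is_valid()-style checks without a null test.
        self.comm = comm if comm is not None else NO_COMM
```

The payoff is that downstream code queries the handle uniformly rather than scattering `== nullptr` / `is None` checks.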
    }

    // Set NCCL communicator on TensorRT execution context
    try {
We need a real state machine or some true semantics here. Nothing else in the runtime just try-catches.
    LOG_INFO(" Current world_size: " << this->world_size);
    LOG_INFO(" Current device_id: " << this->device_info.id);

    try {
Same here, no nested try-catches; you need to provide guarantees about the state of the system.
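One way to read "a real state machine" is an explicit lifecycle with legal transitions, so every failure leaves the engine in a known state instead of whatever a nested try-catch happened to produce. The states and transitions below are illustrative, not the PR's actual design.

```python
from enum import Enum, auto


class CommState(Enum):
    UNINITIALIZED = auto()
    READY = auto()
    FAILED = auto()


# The only transitions the lifecycle permits; anything else is a hard error.
_ALLOWED = {
    (CommState.UNINITIALIZED, CommState.READY),
    (CommState.UNINITIALIZED, CommState.FAILED),
    (CommState.READY, CommState.FAILED),
}


class CommLifecycle:
    def __init__(self) -> None:
        self.state = CommState.UNINITIALIZED

    def transition(self, new_state: CommState) -> None:
        if (self.state, new_state) not in _ALLOWED:
            raise RuntimeError(f"illegal transition {self.state} -> {new_state}")
        self.state = new_state
```

With this in place, "what state is the communicator in?" always has one answer, and callers can assert on it rather than catching exceptions to infer it.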
    logger.error(f"Failed to set NCCL communicator: {e}")
    raise

    def get_nccl_communicator(self) -> Optional[Any]:
We don't need getters and setters; if you want people to use this field, just make it public.
| """Get the NCCL communicator if set.""" | ||
| return self._nccl_comm | ||
|
|
||
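The suggested simplification, sketched with a hypothetical `Module` class: a plain public attribute replaces the `get_`/`set_` pair, which is the idiomatic Python shape (a `property` can be introduced later without breaking callers if validation becomes necessary).

```python
from typing import Any, Optional


class Module:
    def __init__(self) -> None:
        # Public field; no get_nccl_communicator / set_nccl_communicator wrappers.
        self.nccl_comm: Optional[Any] = None


m = Module()
m.nccl_comm = "comm-handle"  # direct assignment instead of a setter call
```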
    def setup_nccl(self, use_pytorch_comm: bool = True) -> None:
This should be on the user
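Moving setup onto the user could look like the following sketch: the caller builds the communicator with whatever stack they use (c10d, raw NCCL bindings) and injects it through the constructor, so the module never calls any NCCL init API. `FakeComm` and `TRTModule` are stand-ins, not this PR's classes.

```python
class FakeComm:
    """Stand-in for a user-created NCCL communicator handle."""
    def __init__(self, rank: int, world_size: int) -> None:
        self.rank, self.world_size = rank, world_size


class TRTModule:
    def __init__(self, nccl_comm) -> None:
        # The module only stores the handle; setup happened on the user side.
        self.nccl_comm = nccl_comm


# User side: they own initialization and hand the result in.
comm = FakeComm(rank=0, world_size=2)
module = TRTModule(nccl_comm=comm)
```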
    uid = nccl.UniqueId.from_bytes(bytes(uid_tensor.cpu().numpy()))

    comm = nccl.Communicator.init(world_size, rank, uid)
Is this a singleton owned by NCCL, or can you have multiple? Can a user just give this to us?
    }
    }

    void TRTEngine::init_nccl_comm(const std::string& group_name) {
Why do we need a method that transparently calls another method?
    LOG_INFO(" Got NCCL backend from ProcessGroup");

    // Cast the backend to ProcessGroupNCCL
    auto* nccl_pg = dynamic_cast<c10d::ProcessGroupNCCL*>(backend.get());
Try to use smart pointers if possible
    weight_name_map: Optional[dict[Any, Any]] = None,
    requires_output_allocator: bool = False,
    symbolic_shape_expressions: Optional[Dict[str, List[Dict[str, Any]]]] = None,
    rank: int = -1,
I think all we need is a flag that this is an md_engine; we can fetch the info from torch.distributed or the env vars internally.
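A sketch of fetching rank/world_size internally rather than threading them through the constructor: prefer `torch.distributed` when it is initialized, and otherwise fall back to the standard `RANK`/`WORLD_SIZE` environment variables that `torchrun` and most launchers export. The helper name is hypothetical.

```python
import os
from typing import Tuple


def infer_dist_info() -> Tuple[int, int]:
    """Return (rank, world_size), preferring torch.distributed over env vars."""
    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
    except ImportError:
        pass  # torch not installed; fall back to the launcher's env vars
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    return rank, world_size
```

With this, `rank: int = -1` in the signature becomes unnecessary; the engine can call the helper at deserialization time when the md_engine flag is set.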