Changes:
- csrc/event.hpp: Update EventHandle, create_event, and stream_wait to use at::cuda::CUDAStream.
- csrc/deep_ep.hpp: Change the comm_stream member to std::optional<at::cuda::CUDAStream> and add a make_cuda_stream helper.
- csrc/deep_ep.cpp: Update all usages of comm_stream and compute_stream to use at::cuda::CUDAStream.

This aligns with the Paddle phi/api/include/compat at::cuda::CUDAStream interface.
Pull request overview
This PR updates DeepEP's Paddle compatibility layer to use at::cuda::CUDAStream (instead of a raw cudaStream_t) for stream-aware operations (events, stream waits, tensor record_stream), aiming to support newer SM90+ workflows while keeping local compilation working.
Changes:
- Switch stream-related APIs in csrc/event.hpp and core buffer logic to accept/use at::cuda::CUDAStream.
- Store the communication stream as an at::cuda::CUDAStream (via std::optional) and add helpers to access the raw cudaStream_t.
- Adjust NVSHMEM link flags and extend .gitignore entries.
Reviewed changes
Copilot reviewed 4 out of 5 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| setup.py | Removes explicit linking of libnvshmem_device.a from extra_link_args. |
| csrc/event.hpp | Migrates event/stream wait helpers from cudaStream_t to at::cuda::CUDAStream. |
| csrc/deep_ep.hpp | Changes comm_stream storage/type, adds raw-stream accessor, and introduces make_cuda_stream. |
| csrc/deep_ep.cpp | Wraps Paddle/NCCL raw streams into at::cuda::CUDAStream and updates call sites to use .stream(). |
| .gitignore | Ignores local draft/planning files and .humanize/. |
```cpp
at::cuda::CUDAStream get_comm_stream() const {
    return comm_stream.value();
}

// Helper to get raw stream for CUDA APIs
cudaStream_t get_comm_stream_raw() const {
    return comm_stream.value().stream();
}
```
Buffer::get_comm_stream() now returns at::cuda::CUDAStream, which is a breaking API change for any call sites expecting a raw cudaStream_t (e.g., the pybind get_comm_stream wrapper in csrc/deep_ep.cpp assigns it to cudaStream_t). Consider either keeping get_comm_stream() returning cudaStream_t and adding a separate get_comm_cuda_stream() (or similar) for the at::cuda::CUDAStream, or updating all call sites to use the new get_comm_stream_raw() accessor explicitly.
Suggested change:

```diff
-at::cuda::CUDAStream get_comm_stream() const {
-    return comm_stream.value();
-}
-// Helper to get raw stream for CUDA APIs
-cudaStream_t get_comm_stream_raw() const {
-    return comm_stream.value().stream();
+// Returns the raw CUDA stream for compatibility with existing call sites.
+cudaStream_t get_comm_stream() const {
+    return comm_stream.value().stream();
+}
+// Returns the underlying at::cuda::CUDAStream for callers that need it.
+at::cuda::CUDAStream get_comm_cuda_stream() const {
+    return comm_stream.value();
+}
+// Helper to get raw stream for CUDA APIs (alias of get_comm_stream()).
+cudaStream_t get_comm_stream_raw() const {
+    return get_comm_stream();
```
```cpp
c10::StreamId sid = static_cast<c10::StreamId>(reinterpret_cast<intptr_t>(raw_stream));
return at::cuda::CUDAStream(c10::Stream(c10::Stream::UNSAFE,
                                        c10::Device(c10::DeviceType::CUDA, device_id),
                                        sid));
```
make_cuda_stream() builds a c10::StreamId by reinterpret_cast-ing a cudaStream_t pointer value and then constructs a c10::Stream with c10::Stream::UNSAFE. In PyTorch/C10, StreamId is not a raw CUDA stream pointer; it encodes stream type/index, so this can create an invalid stream identity and lead to incorrect synchronization or allocator/record_stream behavior. Prefer the official external-stream wrapping API (e.g., a getStreamFromExternal(...) helper if available in your at::cuda compat layer) or store/use the raw cudaStream_t directly for CUDA APIs while only using at::cuda::CUDAStream values that are created through supported constructors.
Suggested change (note that at::cuda::getStreamFromExternal takes the cudaStream_t directly, so no cast of the pointer value is needed):

```diff
-c10::StreamId sid = static_cast<c10::StreamId>(reinterpret_cast<intptr_t>(raw_stream));
-return at::cuda::CUDAStream(c10::Stream(c10::Stream::UNSAFE,
-                                        c10::Device(c10::DeviceType::CUDA, device_id),
-                                        sid));
+// Wrap the external CUDA stream via the official PyTorch API instead of
+// fabricating a c10::StreamId from the raw pointer value.
+return at::cuda::getStreamFromExternal(raw_stream, device_id);
```
```cpp
// Construct at::cuda::CUDAStream from raw cudaStream_t
cudaStream_t raw_stream = comm_ctx->GetStream();
c10::StreamId sid = static_cast<c10::StreamId>(reinterpret_cast<intptr_t>(raw_stream));
comm_stream.emplace(c10::Stream(c10::Stream::UNSAFE,
                                c10::Device(c10::DeviceType::CUDA, device_id),
                                sid));
```
The constructor wraps comm_ctx->GetStream() by casting the raw cudaStream_t value into a c10::StreamId and constructing a c10::Stream with UNSAFE. This is very likely not a valid way to represent an external CUDA stream in C10/ATen and can break stream semantics. After switching to a supported external-stream wrapper (see make_cuda_stream), use that wrapper here as well to avoid repeating the unsafe conversion logic.
Suggested change:

```diff
-// Construct at::cuda::CUDAStream from raw cudaStream_t
-cudaStream_t raw_stream = comm_ctx->GetStream();
-c10::StreamId sid = static_cast<c10::StreamId>(reinterpret_cast<intptr_t>(raw_stream));
-comm_stream.emplace(c10::Stream(c10::Stream::UNSAFE,
-                                c10::Device(c10::DeviceType::CUDA, device_id),
-                                sid));
+// Construct at::cuda::CUDAStream from raw cudaStream_t using the supported wrapper
+cudaStream_t raw_stream = comm_ctx->GetStream();
+comm_stream.emplace(make_cuda_stream(raw_stream, device_id));
```
The Paddle compatibility layer now uses at::cuda::CUDAStream in place of cudaStream_t. This compiles locally, but runtime verification has not been done yet because no SM90+ GPU is available.
See PaddlePaddle/Paddle#78143 for the compatibility-layer changes.