
Commit 5b93a0c

Remove unnecessary cuda sync for better perf
Pull Request resolved: #17315

Right now we always do a CUDA sync before returning from CudaBackend's execute(). However, we only need that sync when copying data from GPU to CPU; operations enqueued on the same stream do not need an explicit sync.

ghstack-source-id: 339914649
@exported-using-ghexport

Differential Revision: [D92193164](https://our.internmc.facebook.com/intern/diff/D92193164/)
1 parent b162f9f commit 5b93a0c
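For context, a minimal sketch of the CUDA stream-ordering rule this change relies on (hypothetical kernel and buffer names, not code from this PR): work submitted to a single stream executes in submission order, so a host-side synchronize is only required when the CPU is about to read the results.

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical kernel standing in for the AOTI-compiled model's kernels.
__global__ void model_kernel(const float* in, float* out, std::size_t n) {
  std::size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 2.0f;
}

void run_once(
    cudaStream_t stream,
    const float* d_in,
    float* d_out,
    float* h_out,
    std::size_t n,
    bool copy_output_to_cpu) {
  model_kernel<<<(n + 255) / 256, 256, 0, stream>>>(d_in, d_out, n);
  if (copy_output_to_cpu) {
    // The host is about to read h_out, so the kernel and the D2H copy must
    // finish first; this is the only path that needs an explicit sync.
    cudaMemcpyAsync(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
  }
  // Skip-copy path: no sync. Any later kernel, copy, or cudaFreeAsync enqueued
  // on the same stream is ordered after model_kernel by the stream itself.
}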

5 files changed

Lines changed: 165 additions & 49 deletions

File tree

backends/aoti/aoti_delegate_handle.h
backends/aoti/slim/core/storage.h
backends/cuda/runtime/cuda_backend.cpp
extension/asr/runner/runner.cpp
extension/asr/runner/runner.h

backends/aoti/aoti_delegate_handle.h

Lines changed: 2 additions & 2 deletions
@@ -84,8 +84,8 @@ struct AOTIDelegateHandle {
   void* so_handle;
   std::string so_path;
   AOTInductorModelContainerHandle container_handle;
-  void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
-                     // dependency
+  void* cuda_stream; // Per-handle CUDA stream. If nullptr, use backend's shared
+                     // stream instead (for skip-copy optimization).
   std::string method_name;

   // Function pointers specific to this handle's shared library

backends/aoti/slim/core/storage.h

Lines changed: 6 additions & 6 deletions
@@ -127,13 +127,13 @@ struct DeviceTraits<c10::DeviceType::CUDA> {
   /// @param ptr Pointer to device memory to free.
   static void free(void* ptr) {
     // Get the current stream for the current device
+    // Currently all cuda slimtensors should be on the same device same stream,
+    // so we can just use the stream on current device.
+    // TODO(gasoonjia): add cuda stream as a member of MaybeOwningStorage to
+    // support multiple devices.
     auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1);
-    if (stream_result.ok()) {
-      ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
-    } else {
-      // Fallback to synchronous free if we can't get the stream
-      ET_CUDA_LOG_WARN(cudaFree(ptr));
-    }
+    ET_CHECK_MSG(stream_result.ok(), "Failed to get current CUDA stream");
+    ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
   }

   /// Copies memory between CPU and CUDA or CUDA and CUDA.
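The storage.h change above drops the synchronous cudaFree fallback and relies entirely on stream-ordered deallocation. A minimal sketch of that semantics under the CUDA 11.2+ stream-ordered allocator (hypothetical buffer, not code from this file):

#include <cuda_runtime.h>
#include <cstddef>

// cudaFreeAsync is enqueued on `stream`, so the deallocation cannot take
// effect before work submitted earlier on the same stream (here the memset)
// has completed. No host-side sync and no synchronous cudaFree are needed.
void alloc_use_free(cudaStream_t stream, std::size_t n) {
  float* d_buf = nullptr;
  cudaMallocAsync(reinterpret_cast<void**>(&d_buf), n * sizeof(float), stream);
  cudaMemsetAsync(d_buf, 0, n * sizeof(float), stream);
  cudaFreeAsync(d_buf, stream);
}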

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 145 additions & 40 deletions
@@ -77,6 +77,7 @@ using slim::c10::DeviceType;
 namespace {
 constexpr char kSkipCopyOutputToCpuForMethod[] =
     "skip_copy_output_to_cpu_for_method";
+constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
 } // anonymous namespace

 class ET_EXPERIMENTAL CudaBackend final
@@ -143,6 +144,36 @@ class ET_EXPERIMENTAL CudaBackend final
     return method_in_csv(method_name, skip_copy_method_);
   }

+  // Create the shared CUDA stream. Called when use_shared_cuda_stream option
+  // is set to true. The presence of shared_cuda_stream_ indicates shared mode.
+  void create_shared_cuda_stream() {
+    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
+    if (shared_cuda_stream_ != nullptr) {
+      return; // Already created
+    }
+    cudaError_t err = cudaStreamCreate(&shared_cuda_stream_);
+    if (err != cudaSuccess) {
+      ET_LOG(
+          Error,
+          "Failed to create shared CUDA stream: %s",
+          cudaGetErrorString(err));
+      return;
+    }
+    ET_LOG(Info, "Created shared CUDA stream: %p", shared_cuda_stream_);
+  }
+
+  // Get the shared CUDA stream. Returns nullptr if not in shared mode.
+  cudaStream_t get_shared_cuda_stream() const {
+    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
+    return shared_cuda_stream_;
+  }
+
+  // Check if we're using shared CUDA stream mode.
+  bool is_using_shared_cuda_stream() const {
+    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
+    return shared_cuda_stream_ != nullptr;
+  }
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
@@ -181,6 +212,19 @@ class ET_EXPERIMENTAL CudaBackend final
   }

 public:
+  // Destructor: clean up the shared CUDA stream if it was created.
+  ~CudaBackend() {
+    if (shared_cuda_stream_ != nullptr) {
+      cudaError_t err = cudaStreamDestroy(shared_cuda_stream_);
+      if (err != cudaSuccess) {
+        ET_LOG(
+            Error,
+            "Failed to destroy shared CUDA stream: %s",
+            cudaGetErrorString(err));
+      }
+    }
+  }
+
   bool is_available() const override {
     return 1;
   }
@@ -201,6 +245,15 @@ class ET_EXPERIMENTAL CudaBackend final
               kSkipCopyOutputToCpuForMethod);
           return Error::InvalidArgument;
         }
+      } else if (std::strcmp(option.key, kUseSharedCudaStream) == 0) {
+        if (auto* val = std::get_if<bool>(&option.value)) {
+          if (*val) {
+            create_shared_cuda_stream();
+          }
+        } else {
+          ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
+          return Error::InvalidArgument;
+        }
       }
     }
     return Error::Ok;
@@ -313,10 +366,27 @@ class ET_EXPERIMENTAL CudaBackend final
           handle->container_handle, static_cast<const uint8_t*>(weights_blob)));
       buffer_res->Free();
     }
-    // Create a CUDA stream for asynchronous execution
-    cudaStream_t cuda_stream;
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
-    handle->cuda_stream = static_cast<void*>(cuda_stream);
+
+    // Use shared CUDA stream if enabled via options, otherwise create one.
+    // A shared stream ensures proper ordering across multiple methods
+    // (e.g., encoder, decoder, sampler) when using skip-copy optimization.
+    if (is_using_shared_cuda_stream()) {
+      // Shared stream mode: set handle's stream to nullptr.
+      // The stream will be retrieved from backend in execute().
+      handle->cuda_stream = nullptr;
+      ET_LOG(
+          Info, "Using shared CUDA stream for method %s", method_name.c_str());
+    } else {
+      // Per-handle stream mode: each handle owns its own stream.
+      cudaStream_t cuda_stream;
+      ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
+      handle->cuda_stream = static_cast<void*>(cuda_stream);
+      ET_LOG(
+          Info,
+          "Created new CUDA stream %p for method %s",
+          handle->cuda_stream,
+          method_name.c_str());
+    }

     return (DelegateHandle*)handle; // Return the handle post-processing
   }
@@ -351,31 +421,30 @@ class ET_EXPERIMENTAL CudaBackend final
     // Process input tensors: convert ETensor (CPU) to SlimTensor (GPU)
     for (size_t i = 0; i < n_inputs; i++) {
       auto* cpu_tensor = &(args[i]->toTensor());
-
-      // Check if input data is already on GPU (skip-copy optimization for
-      // inputs) This can happen when the caller has pre-staged data on GPU
-      cudaPointerAttributes attributes{};
       const void* data_ptr = cpu_tensor->const_data_ptr();
-      if (data_ptr != nullptr) {
-        cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr);
-        if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) {
-          // Data is already on GPU - wrap it directly without copy
-          auto sizes = cpu_tensor->sizes();
-          auto strides = cpu_tensor->strides();
-          std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
-          std::vector<int64_t> strides_vec(strides.begin(), strides.end());
-
-          gpu_inputs[i] = new SlimTensor(slim::from_blob(
-              const_cast<void*>(data_ptr),
-              slim::makeArrayRef(sizes_vec),
-              slim::makeArrayRef(strides_vec),
-              static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
-              DEFAULT_CUDA_DEVICE,
-              0 // storage_offset
-              ));
-
-          continue;
-        }
+
+      // Check if input data is already on GPU by looking up cached outputs.
+      // This avoids calling cudaPointerGetAttributes which is a sync point.
+      // If the data pointer matches a cached output tensor, we know it's on GPU.
+      SlimTensor* cached_tensor = find_cached_tensor_by_data_ptr(data_ptr);
+      if (cached_tensor != nullptr) {
+        // Data is already on GPU from a previous method's output.
+        // Wrap it directly without copy using from_blob.
+        auto sizes = cpu_tensor->sizes();
+        auto strides = cpu_tensor->strides();
+        std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
+        std::vector<int64_t> strides_vec(strides.begin(), strides.end());
+
+        gpu_inputs[i] = new SlimTensor(slim::from_blob(
+            const_cast<void*>(data_ptr),
+            slim::makeArrayRef(sizes_vec),
+            slim::makeArrayRef(strides_vec),
+            static_cast<slim::c10::ScalarType>(cpu_tensor->scalar_type()),
+            DEFAULT_CUDA_DEVICE,
+            0 // storage_offset
+            ));
+
+        continue;
       }

       // Data is on CPU - use from_etensor to copy to GPU
@@ -406,13 +475,19 @@ class ET_EXPERIMENTAL CudaBackend final
     // expects ETensor* as input/output. We avoid changing its signature since
     // it's shared with the Metal backend. Instead, we reinterpret_cast
     // SlimTensor* to Tensor*
+    //
+    // Get the CUDA stream: use handle's stream if set, otherwise get from
+    // backend's shared stream.
+    cudaStream_t cuda_stream = handle->cuda_stream != nullptr
+        ? static_cast<cudaStream_t>(handle->cuda_stream)
+        : get_shared_cuda_stream();
     AOTIRuntimeError error = handle->run(
         handle->container_handle,
         reinterpret_cast<Tensor**>(gpu_inputs.data()),
         n_inputs,
         reinterpret_cast<Tensor**>(gpu_outputs.data()),
         n_outputs,
-        handle->cuda_stream,
+        static_cast<void*>(cuda_stream),
        nullptr);

     ET_CHECK_OR_RETURN_ERROR(
@@ -423,17 +498,16 @@ class ET_EXPERIMENTAL CudaBackend final

     const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

-    // Synchronize CUDA stream to ensure kernel execution is complete
-    // before accessing output data (either for copy or skip-copy path)
-    cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
-    cudaError_t sync_err = cudaStreamSynchronize(cuda_stream);
-    ET_CHECK_OR_RETURN_ERROR(
-        sync_err == cudaSuccess,
-        Internal,
-        "cudaStreamSynchronize failed: %s",
-        cudaGetErrorString(sync_err));
-
     if (copy_outputs) {
+      // Synchronize CUDA stream before D2H copy. This is required because
+      // cudaMemcpy is not stream-ordered and needs the kernel to complete.
+      cudaError_t sync_err = cudaStreamSynchronize(cuda_stream);
+      ET_CHECK_OR_RETURN_ERROR(
+          sync_err == cudaSuccess,
+          Internal,
+          "cudaStreamSynchronize failed: %s",
+          cudaGetErrorString(sync_err));
+
       // Deep copy GPU SlimTensor results back to CPU ETensors
       for (size_t i = 0; i < n_outputs; i++) {
         auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
@@ -448,6 +522,12 @@ class ET_EXPERIMENTAL CudaBackend final
       // Skip-copy optimization: point ETensor directly to GPU data.
       // The caller is responsible for handling GPU data directly.
       //
+      // No cudaStreamSynchronize needed here because:
+      // 1. All operations (kernel, allocations, frees) are on the same stream
+      // 2. cudaFreeAsync is stream-ordered, so CUDA guarantees the kernel
+      //    completes before any memory is freed
+      // 3. The next execution's operations will also be ordered on this stream
+      //
       // Lifetime management: We cache the newly created GPU tensors and delete
       // the previous round's tensors, since they are no longer needed.
       {
@@ -495,7 +575,9 @@ class ET_EXPERIMENTAL CudaBackend final
       }
     }

-    // Destroy the CUDA stream if it exists
+    // Destroy the CUDA stream only if this handle owns it (non-null).
+    // When cuda_stream is nullptr, the handle uses the backend's shared
+    // stream which is managed by the backend singleton via shared_ptr.
     if (handle->cuda_stream != nullptr) {
       cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
       cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
@@ -541,13 +623,36 @@ class ET_EXPERIMENTAL CudaBackend final
   mutable std::mutex skip_copy_method_mutex_;
   std::string skip_copy_method_;

+  // Shared CUDA stream for all methods. When set (non-null), all methods use
+  // the same stream to ensure proper ordering (critical for skip-copy
+  // optimization). Created when use_shared_cuda_stream option is set to true.
+  // Cleaned up in destructor.
+  mutable std::mutex cuda_stream_mutex_;
+  cudaStream_t shared_cuda_stream_ = nullptr;
+
   // Cached output tensors for skip-copy optimization.
   // When skip-copy is enabled, output SlimTensors are cached here to keep
   // the underlying GPU memory alive while the caller processes the results.
   // Maps each AOTIDelegateHandle* to its vector of cached output tensors.
   mutable std::mutex cached_outputs_mutex_;
   mutable std::unordered_map<AOTIDelegateHandle*, std::vector<SlimTensor*>>
       cached_outputs_;
+
+  // Finds a cached SlimTensor by data pointer.
+  // Returns the cached SlimTensor if found, nullptr otherwise.
+  // This is used to detect if input data is already on GPU from a previous
+  // method's output, avoiding the need for cudaPointerGetAttributes.
+  SlimTensor* find_cached_tensor_by_data_ptr(const void* data_ptr) const {
+    std::lock_guard<std::mutex> guard(cached_outputs_mutex_);
+    for (const auto& [handle, tensors] : cached_outputs_) {
+      for (SlimTensor* tensor : tensors) {
+        if (tensor != nullptr && tensor->data_ptr() == data_ptr) {
+          return tensor;
+        }
+      }
+    }
+    return nullptr;
+  }
 };

 } // namespace executorch::backends::cuda
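To make the shared-stream rationale concrete, here is a hedged sketch (hypothetical kernels, not the backend's real methods) of two delegate methods chained on one stream. With separate per-method streams, the second launch could begin before the first finished unless an event wait or explicit sync were added between them.

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical stand-ins for two delegate methods (e.g. encoder and decoder).
__global__ void encoder_kernel(const float* in, float* out, std::size_t n) {
  std::size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 2.0f;
}
__global__ void decoder_kernel(const float* in, float* out, std::size_t n) {
  std::size_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] + 1.0f;
}

// Both "methods" are launched on the same shared stream, so the decoder is
// guaranteed to consume the encoder's completed output with no intervening
// cudaStreamSynchronize; the host only syncs when it finally copies results.
void pipeline(
    cudaStream_t shared_stream,
    const float* d_in,
    float* d_mid,
    float* d_out,
    std::size_t n) {
  encoder_kernel<<<(n + 255) / 256, 256, 0, shared_stream>>>(d_in, d_mid, n);
  decoder_kernel<<<(n + 255) / 256, 256, 0, shared_stream>>>(d_mid, d_out, n);
}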

extension/asr/runner/runner.cpp

Lines changed: 10 additions & 1 deletion
@@ -46,6 +46,8 @@ AsrRunner::AsrRunner(
   }
 }

+AsrRunner::~AsrRunner() = default;
+
 bool AsrRunner::is_loaded() const {
   return module_ && encoder_method_loaded_ && decoder_method_loaded_ &&
       (!sampler_method_present_ || sampler_method_loaded_) && tokenizer_ &&
@@ -121,13 +123,20 @@ Error AsrRunner::load() {
 #ifdef CUDA_AVAILABLE
   // Skip copying outputs to CPU. When a sampler exists, keep both encoder and
   // decoder outputs on device and pass decoder logits directly into sampler.
-  executorch::runtime::BackendOptions<1> backend_options;
+  // The backend will automatically create a shared CUDA stream for all methods
+  // when skip-copy is enabled to ensure proper ordering.
+  executorch::runtime::BackendOptions<2> backend_options;
   std::string skip_methods = kEncoderMethodName;
   if (sampler_method_present_) {
     skip_methods.append(",").append(kDecoderMethodName);
   }
   ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
       "skip_copy_output_to_cpu_for_method", skip_methods.c_str()));
+  // Enable shared CUDA stream for all methods when skip-copy is used.
+  // This ensures proper ordering between encoder/decoder/sampler outputs.
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      backend_options.set_option("use_shared_cuda_stream", true));
+
   const auto opt_err =
       executorch::runtime::set_option("CudaBackend", backend_options.view());
   if (opt_err != ::executorch::runtime::Error::Ok) {

extension/asr/runner/runner.h

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ class ET_EXPERIMENTAL AsrRunner {
       std::optional<std::string> data_path,
       const std::string& tokenizer_path);

+  ~AsrRunner();
+
   /**
    * Returns true when the module and tokenizer are ready for inference.
    */
