diff --git a/backends/aoti/aoti_delegate_handle.h b/backends/aoti/aoti_delegate_handle.h
index b14e02da9ef..2bc6abf9bd1 100644
--- a/backends/aoti/aoti_delegate_handle.h
+++ b/backends/aoti/aoti_delegate_handle.h
@@ -84,8 +84,6 @@ struct AOTIDelegateHandle {
   void* so_handle;
   std::string so_path;
   AOTInductorModelContainerHandle container_handle;
-  void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
-                     // dependency
   std::string method_name;

   // Function pointers specific to this handle's shared library
diff --git a/backends/aoti/slim/core/storage.h b/backends/aoti/slim/core/storage.h
index bd227dbb43a..73c4d32d955 100644
--- a/backends/aoti/slim/core/storage.h
+++ b/backends/aoti/slim/core/storage.h
@@ -127,16 +127,47 @@ struct DeviceTraits {
   /// @param ptr Pointer to device memory to free.
   static void free(void* ptr) {
     // Get the current stream for the current device
+    // Currently all CUDA SlimTensors should be on the same device and stream,
+    // so we can just use the stream of the current device.
+    // TODO(gasoonjia): add the CUDA stream as a member of MaybeOwningStorage
+    // to support multiple devices.
     auto stream_result = executorch::backends::cuda::getCurrentCUDAStream(-1);
-    if (stream_result.ok()) {
-      ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
+    ET_CHECK_MSG(stream_result.ok(), "Failed to get current CUDA stream");
+    ET_CUDA_LOG_WARN(cudaFreeAsync(ptr, stream_result.get()));
+  }
+
+  /// Copies memory between CPU and CUDA or CUDA and CUDA asynchronously.
+  /// @param dst Destination pointer.
+  /// @param src Source pointer.
+  /// @param nbytes Number of bytes to copy.
+  /// @param dst_device Destination device.
+  /// @param src_device Source device.
+  /// @param stream CUDA stream on which to enqueue the copy.
+  static void memcpy_async(
+      void* dst,
+      const void* src,
+      size_t nbytes,
+      const c10::Device& dst_device,
+      const c10::Device& src_device,
+      cudaStream_t stream) {
+    cudaMemcpyKind direction = cudaMemcpyDeviceToDevice;
+
+    if (src_device.is_cpu()) {
+      direction = cudaMemcpyHostToDevice;
+    } else if (dst_device.is_cpu()) {
+      direction = cudaMemcpyDeviceToHost;
     } else {
-      // Fallback to synchronous free if we can't get the stream
-      ET_CUDA_LOG_WARN(cudaFree(ptr));
+      ET_CHECK_MSG(
+          src_device.index() == dst_device.index(),
+          "CUDA memcpy across different device indices not supported: %d != %d",
+          static_cast<int>(src_device.index()),
+          static_cast<int>(dst_device.index()));
     }
+
+    ET_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, direction, stream));
   }

-  /// Copies memory between CPU and CUDA or CUDA and CUDA.
+  /// Copies memory between CPU and CUDA or CUDA and CUDA synchronously.
   /// @param dst Destination pointer.
   /// @param src Source pointer.
   /// @param nbytes Number of bytes to copy.
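The new `memcpy_async` infers the copy direction from the source and destination devices and leaves synchronization entirely to the caller. A minimal standalone sketch of the same direction-selection logic, using only raw CUDA calls (no SlimTensor or DeviceTraits types; buffer names and sizes are illustrative):

```cpp
// Sketch of direction selection for an async copy plus caller-owned sync.
#include <cuda_runtime.h>
#include <vector>

static cudaMemcpyKind pick_direction(bool src_is_cpu, bool dst_is_cpu) {
  if (src_is_cpu && !dst_is_cpu) return cudaMemcpyHostToDevice;
  if (!src_is_cpu && dst_is_cpu) return cudaMemcpyDeviceToHost;
  if (!src_is_cpu && !dst_is_cpu) return cudaMemcpyDeviceToDevice;
  return cudaMemcpyHostToHost;
}

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  std::vector<float> host(1024, 1.0f);
  float* device_buf = nullptr;
  cudaMalloc(&device_buf, host.size() * sizeof(float));

  // Host -> device staging copy, enqueued on the stream.
  cudaMemcpyAsync(device_buf, host.data(), host.size() * sizeof(float),
                  pick_direction(/*src_is_cpu=*/true, /*dst_is_cpu=*/false),
                  stream);

  // The caller owns synchronization: nothing may consume device_buf's
  // contents on the host side until the stream has caught up.
  cudaStreamSynchronize(stream);

  cudaFree(device_buf);
  cudaStreamDestroy(stream);
  return 0;
}
```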
diff --git a/backends/cuda/runtime/TARGETS b/backends/cuda/runtime/TARGETS
index 173fb95a399..5d1bdff4b0f 100644
--- a/backends/cuda/runtime/TARGETS
+++ b/backends/cuda/runtime/TARGETS
@@ -95,6 +95,9 @@ runtime.cxx_library(
     srcs = [
         "cuda_backend.cpp",
     ],
+    headers = [
+        "cuda_delegate_handle.h",
+    ],
     # @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
     link_whole = True,
     supports_python_dlopen = True,
diff --git a/backends/cuda/runtime/cuda_backend.cpp b/backends/cuda/runtime/cuda_backend.cpp
index 8e1a871ad4c..a55c2b4f4cb 100644
--- a/backends/cuda/runtime/cuda_backend.cpp
+++ b/backends/cuda/runtime/cuda_backend.cpp
@@ -37,6 +37,7 @@
 // Include our shim layer headers
 #include
+#include <executorch/backends/cuda/runtime/cuda_delegate_handle.h>
 #include
 #include
 #include
@@ -77,6 +78,7 @@ using slim::c10::DeviceType;
 namespace {
 constexpr char kSkipCopyOutputToCpuForMethod[] =
     "skip_copy_output_to_cpu_for_method";
+constexpr char kUseSharedCudaStream[] = "use_shared_cuda_stream";
 } // anonymous namespace

 class ET_EXPERIMENTAL CudaBackend final
@@ -143,6 +145,33 @@ class ET_EXPERIMENTAL CudaBackend final
     return method_in_csv(method_name, skip_copy_method_);
   }

+  // Create the shared CUDA stream. Called when the use_shared_cuda_stream
+  // option is set to true. The presence of shared_cuda_stream_ indicates
+  // shared mode.
+  void create_cuda_stream() {
+    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
+    if (shared_cuda_stream_ != nullptr) {
+      return; // Already created
+    }
+    shared_cuda_stream_ = cuda::create_cuda_stream();
+    if (shared_cuda_stream_ == nullptr) {
+      ET_LOG(Error, "Failed to create shared CUDA stream");
+      return;
+    }
+    ET_LOG(Info, "Created shared CUDA stream: %p", *shared_cuda_stream_);
+  }
+
+  // Get the shared CUDA stream. Returns nullptr if not in shared mode.
+  std::shared_ptr<cudaStream_t> get_shared_cuda_stream() const {
+    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
+    return shared_cuda_stream_;
+  }
+
+  // Check whether we're using shared CUDA stream mode.
+  bool is_using_shared_cuda_stream() const {
+    std::lock_guard<std::mutex> guard(cuda_stream_mutex_);
+    return shared_cuda_stream_ != nullptr;
+  }
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
@@ -201,6 +230,15 @@ class ET_EXPERIMENTAL CudaBackend final
              kSkipCopyOutputToCpuForMethod);
          return Error::InvalidArgument;
        }
+      } else if (std::strcmp(option.key, kUseSharedCudaStream) == 0) {
+        if (auto* val = std::get_if<bool>(&option.value)) {
+          if (*val) {
+            create_cuda_stream();
+          }
+        } else {
+          ET_LOG(Error, "Option %s must be a boolean.", kUseSharedCudaStream);
+          return Error::InvalidArgument;
+        }
       }
     }
     return Error::Ok;
@@ -282,7 +320,7 @@ class ET_EXPERIMENTAL CudaBackend final
     processed->Free();

     // Create handle and load function pointers into it
-    AOTIDelegateHandle* handle = new AOTIDelegateHandle();
+    cuda::CudaDelegateHandle* handle = new cuda::CudaDelegateHandle();
     handle->so_handle = lib_handle;
     handle->so_path = so_path.string();
     handle->method_name = method_name;
@@ -313,10 +351,31 @@ class ET_EXPERIMENTAL CudaBackend final
           handle->container_handle, static_cast(weights_blob)));
       buffer_res->Free();
     }
-    // Create a CUDA stream for asynchronous execution
-    cudaStream_t cuda_stream;
-    ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&cuda_stream));
-    handle->cuda_stream = static_cast<void*>(cuda_stream);
+
+    // Use the shared CUDA stream if enabled via options; otherwise create one.
+    // A shared stream ensures proper ordering across multiple methods
+    // (e.g., encoder, decoder, sampler) when using the skip-copy optimization.
+    if (is_using_shared_cuda_stream()) {
+      // Shared stream mode: all handles share the same stream.
+      handle->cuda_stream = get_shared_cuda_stream();
+      ET_LOG(
+          Info,
+          "Using shared CUDA stream %p for method %s",
+          handle->get_cuda_stream(),
+          method_name.c_str());
+    } else {
+      // Per-handle stream mode: each handle owns its own stream.
+      handle->cuda_stream = cuda::create_cuda_stream();
+      if (handle->cuda_stream == nullptr) {
+        delete handle;
+        return Error::Internal;
+      }
+      ET_LOG(
+          Info,
+          "Created new CUDA stream %p for method %s",
+          handle->get_cuda_stream(),
+          method_name.c_str());
+    }

     return (DelegateHandle*)handle; // Return the handle post-processing
   }

@@ -326,7 +385,7 @@ class ET_EXPERIMENTAL CudaBackend final
       BackendExecutionContext& context,
       DelegateHandle* handle_,
       Span<EValue*> args) const override {
-    AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
+    cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_;

     size_t n_inputs;
     handle->get_num_inputs(handle->container_handle, &n_inputs);
@@ -334,6 +393,8 @@ class ET_EXPERIMENTAL CudaBackend final
     size_t n_outputs;
     handle->get_num_outputs(handle->container_handle, &n_outputs);

+    setCurrentCUDAStream(handle->get_cuda_stream(), 0); // Route work onto this handle's stream.
+
     ET_CHECK_OR_RETURN_ERROR(
         n_inputs + n_outputs == args.size(),
         InvalidArgument,
@@ -351,34 +412,37 @@ class ET_EXPERIMENTAL CudaBackend final
     // Process input tensors: convert ETensor (CPU) to SlimTensor (GPU)
     for (size_t i = 0; i < n_inputs; i++) {
       auto* cpu_tensor = &(args[i]->toTensor());
-
-      // Check if input data is already on GPU (skip-copy optimization for
-      // inputs) This can happen when the caller has pre-staged data on GPU
-      cudaPointerAttributes attributes{};
       const void* data_ptr = cpu_tensor->const_data_ptr();
-      if (data_ptr != nullptr) {
-        cudaError_t err = cudaPointerGetAttributes(&attributes, data_ptr);
-        if (err == cudaSuccess && attributes.type == cudaMemoryTypeDevice) {
-          // Data is already on GPU - wrap it directly without copy
-          auto sizes = cpu_tensor->sizes();
-          auto strides = cpu_tensor->strides();
-          std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
-          std::vector<int64_t> strides_vec(strides.begin(), strides.end());
-
-          gpu_inputs[i] = new SlimTensor(slim::from_blob(
-              const_cast<void*>(data_ptr),
-              slim::makeArrayRef(sizes_vec),
-              slim::makeArrayRef(strides_vec),
-              static_cast(cpu_tensor->scalar_type()),
-              DEFAULT_CUDA_DEVICE,
-              0 // storage_offset
-              ));
-
-          continue;
-        }
+
+      // Check if the input data is already on GPU by looking up cached
+      // outputs. This avoids calling cudaPointerGetAttributes, which is a
+      // synchronization point. If the data pointer matches a cached output
+      // tensor, we know it is on GPU.
+      SlimTensor* cached_tensor = find_cached_tensor_by_data_ptr(data_ptr);
+      if (cached_tensor != nullptr) {
+        // Data is already on GPU from a previous method's output. Use it
+        // directly, without a copy, via from_blob and the input ETensor's
+        // metadata. We do not use cached_tensor itself as gpu_inputs[i]
+        // because, although the underlying data is the same, the shape and
+        // strides may differ between the cached tensor and the current
+        // input tensor.
+        auto sizes = cpu_tensor->sizes();
+        auto strides = cpu_tensor->strides();
+        std::vector<int64_t> sizes_vec(sizes.begin(), sizes.end());
+        std::vector<int64_t> strides_vec(strides.begin(), strides.end());
+        gpu_inputs[i] = new SlimTensor(slim::from_blob(
+            const_cast<void*>(data_ptr),
+            slim::makeArrayRef(sizes_vec),
+            slim::makeArrayRef(strides_vec),
+            static_cast(cpu_tensor->scalar_type()),
+            DEFAULT_CUDA_DEVICE,
+            0 // storage_offset
+            ));
+
+        continue;
       }
-      // Data is on CPU - use from_etensor to copy to GPU
+      // Data is not cached, so it must be on CPU; use from_etensor to copy it to GPU
       gpu_inputs[i] = new SlimTensor(
           from_etensor(*cpu_tensor, CPU_DEVICE, DEFAULT_CUDA_DEVICE));
     }
@@ -406,13 +470,16 @@ class ET_EXPERIMENTAL CudaBackend final
     // expects ETensor* as input/output. We avoid changing its signature since
     // it's shared with the Metal backend. Instead, we reinterpret_cast
     // SlimTensor* to Tensor*
+    //
+    // Get the CUDA stream from the handle.
+    cudaStream_t cuda_stream = handle->get_cuda_stream();
     AOTIRuntimeError error = handle->run(
         handle->container_handle,
         reinterpret_cast(gpu_inputs.data()),
         n_inputs,
         reinterpret_cast(gpu_outputs.data()),
         n_outputs,
-        handle->cuda_stream,
+        static_cast<void*>(cuda_stream),
         nullptr);

     ET_CHECK_OR_RETURN_ERROR(
@@ -423,31 +490,36 @@ class ET_EXPERIMENTAL CudaBackend final
     const bool copy_outputs =
         !should_skip_copy_for_method(handle->method_name);

-    // Synchronize CUDA stream to ensure kernel execution is complete
-    // before accessing output data (either for copy or skip-copy path)
-    cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
-    cudaError_t sync_err = cudaStreamSynchronize(cuda_stream);
-    ET_CHECK_OR_RETURN_ERROR(
-        sync_err == cudaSuccess,
-        Internal,
-        "cudaStreamSynchronize failed: %s",
-        cudaGetErrorString(sync_err));
-
     if (copy_outputs) {
-      // Deep copy GPU SlimTensor results back to CPU ETensors
+      // Deep copy GPU SlimTensor results back to CPU ETensors (async)
+      size_t total_output_bytes = 0;
       for (size_t i = 0; i < n_outputs; i++) {
         auto* cpu_output_tensor = &(args[i + n_inputs]->toTensor());
         ET_CHECK_OK_OR_RETURN_ERROR(
-            copy_slimtensor_to_etensor(gpu_outputs[i], cpu_output_tensor),
+            copy_slimtensor_to_etensor_async(
+                gpu_outputs[i], cpu_output_tensor, cuda_stream),
             "Failed to copy GPU output %zu back to CPU ETensor",
             i);
+        total_output_bytes += gpu_outputs[i]->nbytes();
+      }
+
+      // Only sync for small outputs (like the sampler's single int64).
+      // Large outputs (e.g., logits) have enough CPU processing time after
+      // execute() returns for the async copy to complete before the data
+      // is actually accessed.
+      // TODO(gasoonjia): Investigate the root cause of the perf regression
+      // seen with an unconditional sync and remove this heuristic.
+      constexpr size_t kSyncThresholdBytes = 1024; // 1KB
+      if (total_output_bytes < kSyncThresholdBytes) {
+        cudaStreamSynchronize(cuda_stream);
       }
+
       // Cleanup gpu_outputs after copying - they are no longer needed
       delete_slimtensor_vector(gpu_outputs);
     } else {
       // Skip-copy optimization: point ETensor directly to GPU data.
       // The caller is responsible for handling GPU data directly.
-      //
+
       // Lifetime management: We cache the newly created GPU tensors and delete
       // the previous round's tensors, since they are no longer needed.
       {
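The conditional synchronization above trades a short blocking wait for tiny outputs (which the caller reads immediately) against letting large output copies overlap with the CPU work that follows `execute()`. A hypothetical helper sketching the same heuristic in isolation; only the 1 KB threshold comes from this diff, the function name and factoring are illustrative:

```cpp
#include <cuda_runtime.h>
#include <cstddef>

constexpr size_t kSyncThresholdBytes = 1024; // 1 KiB, as in this diff

// After enqueuing async device-to-host output copies on `stream`, block only
// when the total output is tiny (e.g., a sampler's single token id).
inline cudaError_t maybe_sync_small_outputs(
    cudaStream_t stream, size_t total_output_bytes) {
  if (total_output_bytes < kSyncThresholdBytes) {
    // Small outputs are typically consumed right away by the caller,
    // so wait for the copies to land before returning.
    return cudaStreamSynchronize(stream);
  }
  // Large outputs: defer synchronization to a later point on the same stream.
  return cudaSuccess;
}
```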
@@ -483,7 +555,7 @@ class ET_EXPERIMENTAL CudaBackend final
     if (handle_ == nullptr) {
       return;
     }
-    AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
+    cuda::CudaDelegateHandle* handle = (cuda::CudaDelegateHandle*)handle_;

     // Clean up cached output tensors for this handle
     {
@@ -495,16 +567,10 @@ class ET_EXPERIMENTAL CudaBackend final
       }
     }

-    // Destroy the CUDA stream if it exists
-    if (handle->cuda_stream != nullptr) {
-      cudaStream_t cuda_stream = static_cast<cudaStream_t>(handle->cuda_stream);
-      cudaError_t stream_err = cudaStreamDestroy(cuda_stream);
-      ET_CHECK_OR_LOG_ERROR(
-          stream_err == cudaSuccess,
-          "Failed to destroy CUDA stream: %s",
-          cudaGetErrorString(stream_err));
-      handle->cuda_stream = nullptr;
-    }
+    // The CUDA stream is managed by a shared_ptr in the handle. It is
+    // destroyed automatically when the last handle using it is destroyed,
+    // so just reset our reference here.
+    handle->cuda_stream.reset();

     // NOTE: AOTInductorModelContainerDelete does not work correctly with
     // multiple .so files. Deleting one container frees shared resources,
@@ -541,13 +607,38 @@ class ET_EXPERIMENTAL CudaBackend final
   mutable std::mutex skip_copy_method_mutex_;
   std::string skip_copy_method_;

+  // Shared CUDA stream for all methods. When set (non-null), all methods use
+  // the same stream to ensure proper ordering (critical for the skip-copy
+  // optimization). Created when the use_shared_cuda_stream option is set to
+  // true. Managed via shared_ptr so it is automatically cleaned up when the
+  // last handle is destroyed.
+  mutable std::mutex cuda_stream_mutex_;
+  std::shared_ptr<cudaStream_t> shared_cuda_stream_ = nullptr;
+
   // Cached output tensors for skip-copy optimization.
   // When skip-copy is enabled, output SlimTensors are cached here to keep
   // the underlying GPU memory alive while the caller processes the results.
-  // Maps each AOTIDelegateHandle* to its vector of cached output tensors.
+  // Maps each CudaDelegateHandle* to its vector of cached output tensors.
   mutable std::mutex cached_outputs_mutex_;
-  mutable std::unordered_map<AOTIDelegateHandle*, std::vector<SlimTensor*>>
-      cached_outputs_;
+  mutable std::
+      unordered_map<cuda::CudaDelegateHandle*, std::vector<SlimTensor*>>
+          cached_outputs_;
+
+  // Finds a cached SlimTensor by data pointer.
+  // Returns the cached SlimTensor if found, nullptr otherwise.
+  // This is used to detect whether input data is already on GPU from a
+  // previous method's output, avoiding the need for cudaPointerGetAttributes.
+  SlimTensor* find_cached_tensor_by_data_ptr(const void* data_ptr) const {
+    std::lock_guard<std::mutex> guard(cached_outputs_mutex_);
+    for (const auto& [handle, tensors] : cached_outputs_) {
+      for (SlimTensor* tensor : tensors) {
+        if (tensor != nullptr && tensor->data_ptr() == data_ptr) {
+          return tensor;
+        }
+      }
+    }
+    return nullptr;
+  }
 };

 } // namespace executorch::backends::cuda
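For reference, this is roughly how a client enables the new option together with skip-copy. The option keys, `BackendOptions`, `set_option`, and `ET_CHECK_OK_OR_RETURN_ERROR` are taken from this diff (see the `AsrRunner` change below); the missing includes and the method names `"encoder,decoder"` are placeholders:

```cpp
// Sketch only: the ExecuTorch runtime headers providing BackendOptions,
// set_option, and ET_CHECK_OK_OR_RETURN_ERROR are assumed to be included.
executorch::runtime::Error configure_cuda_backend() {
  executorch::runtime::BackendOptions<2> backend_options;

  // Keep these methods' outputs on the GPU instead of copying them back to
  // CPU after every execute(). "encoder,decoder" is an illustrative CSV of
  // method names.
  ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
      "skip_copy_output_to_cpu_for_method", "encoder,decoder"));

  // Route every method through one shared CUDA stream so cross-method
  // producer/consumer ordering holds without extra synchronization.
  ET_CHECK_OK_OR_RETURN_ERROR(
      backend_options.set_option("use_shared_cuda_stream", true));

  return executorch::runtime::set_option(
      "CudaBackend", backend_options.view());
}
```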
diff --git a/backends/cuda/runtime/cuda_delegate_handle.h b/backends/cuda/runtime/cuda_delegate_handle.h
new file mode 100644
index 00000000000..02d3356379f
--- /dev/null
+++ b/backends/cuda/runtime/cuda_delegate_handle.h
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cuda_runtime.h>
+#include <memory>
+#include <executorch/backends/aoti/aoti_delegate_handle.h>
+
+namespace executorch {
+namespace backends {
+namespace cuda {
+
+// Shared CUDA stream wrapper with proper RAII cleanup.
+// This ensures the stream is destroyed when all handles using it are
+// destroyed.
+struct CudaStreamDeleter {
+  void operator()(cudaStream_t* stream) const {
+    if (stream != nullptr && *stream != nullptr) {
+      cudaStreamDestroy(*stream);
+    }
+    delete stream;
+  }
+};
+
+// Creates a new shared CUDA stream.
+// Returns nullptr on failure.
+inline std::shared_ptr<cudaStream_t> create_cuda_stream() {
+  cudaStream_t stream;
+  cudaError_t err = cudaStreamCreate(&stream);
+  if (err != cudaSuccess) {
+    return nullptr;
+  }
+  return std::shared_ptr<cudaStream_t>(
+      new cudaStream_t(stream), CudaStreamDeleter());
+}
+
+// CUDA-specific delegate handle that extends AOTIDelegateHandle.
+// This consolidates CUDA stream management into a single location.
+struct CudaDelegateHandle : public aoti::AOTIDelegateHandle {
+  // CUDA stream for this handle; supports both shared and per-handle modes.
+  // In shared mode, all CUDA delegate handles share the same stream (e.g.,
+  // for the skip-copy optimization): they all hold a reference to the same
+  // shared_ptr, and the stream is destroyed automatically when the last
+  // handle is destroyed. In per-handle mode, every CUDA delegate handle owns
+  // its own stream.
+  std::shared_ptr<cudaStream_t> cuda_stream;
+
+  // Get the raw CUDA stream pointer for use in CUDA API calls.
+  // Returns nullptr if no stream is set.
+  cudaStream_t get_cuda_stream() const {
+    return cuda_stream ? *cuda_stream : nullptr;
+  }
+
+  // Check if this handle has a valid CUDA stream.
+  bool has_cuda_stream() const {
+    return cuda_stream != nullptr && *cuda_stream != nullptr;
+  }
+};
+
+} // namespace cuda
+} // namespace backends
+} // namespace executorch
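The ownership model above hinges on `std::shared_ptr<cudaStream_t>` with a custom deleter: the stream outlives any individual handle and is destroyed exactly once, when the last reference is released. A minimal standalone sketch of that behavior (handle names are illustrative):

```cpp
#include <cuda_runtime.h>
#include <cassert>
#include <memory>

struct StreamDeleter {
  void operator()(cudaStream_t* stream) const {
    if (stream != nullptr && *stream != nullptr) {
      cudaStreamDestroy(*stream);
    }
    delete stream;
  }
};

int main() {
  cudaStream_t raw;
  if (cudaStreamCreate(&raw) != cudaSuccess) {
    return 1;
  }
  auto shared =
      std::shared_ptr<cudaStream_t>(new cudaStream_t(raw), StreamDeleter());

  // Two "handles" (e.g., encoder and decoder) share the same stream.
  std::shared_ptr<cudaStream_t> encoder_stream = shared;
  std::shared_ptr<cudaStream_t> decoder_stream = shared;
  shared.reset();

  encoder_stream.reset(); // stream still alive: decoder holds a reference
  assert(decoder_stream.use_count() == 1);
  decoder_stream.reset(); // last reference dropped -> cudaStreamDestroy runs
  return 0;
}
```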
diff --git a/backends/cuda/runtime/utils.h b/backends/cuda/runtime/utils.h
index c39d8160071..d626fb5853b 100644
--- a/backends/cuda/runtime/utils.h
+++ b/backends/cuda/runtime/utils.h
@@ -89,7 +89,56 @@ inline executorch::runtime::Error _check_tensor_metadata(
 } // namespace

 /**
- * Copies data from a SlimTensor to an ETensor.
+ * Copies data from a SlimTensor to an ETensor asynchronously.
+ *
+ * This function converts a SlimTensor back to an ETensor using an async copy.
+ * The ETensor is assumed to always reside on CPU, so this handles both
+ * CPU→CPU and GPU→CPU copies. The function will resize the ETensor if needed
+ * and copy the data asynchronously on the provided CUDA stream.
+ *
+ * NOTE: The caller must ensure proper synchronization after calling this
+ * function before the ETensor data is accessed on the CPU side.
+ *
+ * @param slim_tensor Pointer to the source SlimTensor (must not be null).
+ * @param etensor Pointer to the destination ETensor (must not be null).
+ * @param stream The CUDA stream to use for the async copy.
+ * @return Error::Ok on success, or an appropriate error code on failure.
+ */
+inline executorch::runtime::Error copy_slimtensor_to_etensor_async(
+    const executorch::backends::aoti::slim::SlimTensor* slim_tensor,
+    executorch::runtime::etensor::Tensor* etensor,
+    cudaStream_t stream) {
+  _check_tensor_metadata(slim_tensor, etensor);
+
+  // Copy data from SlimTensor to ETensor.
+  // The SlimTensor may be on GPU or CPU; the ETensor is always on CPU.
+  size_t nbytes = slim_tensor->nbytes();
+  if (nbytes > 0) {
+    void* dst_data = etensor->mutable_data_ptr();
+    const void* src_data = slim_tensor->data_ptr();
+
+    if (slim_tensor->is_cpu()) {
+      // CPU → CPU copy (always synchronous)
+      std::memcpy(dst_data, src_data, nbytes);
+    } else {
+      // GPU → CPU async copy
+      executorch::backends::aoti::slim::DeviceTraits<
+          executorch::backends::aoti::slim::c10::DeviceType::CUDA>::
+          memcpy_async(
+              dst_data,
+              src_data,
+              nbytes,
+              executorch::backends::aoti::slim::CPU_DEVICE,
+              slim_tensor->device(),
+              stream);
+    }
+  }
+
+  return executorch::runtime::Error::Ok;
+}
+
+/**
+ * Copies data from a SlimTensor to an ETensor synchronously.
  *
  * This function converts a SlimTensor back to an ETensor. The ETensor is
  * assumed to always reside on CPU, so this handles both CPU→CPU and GPU→CPU
@@ -115,7 +164,7 @@ inline executorch::runtime::Error copy_slimtensor_to_etensor(
       // CPU → CPU copy
       std::memcpy(dst_data, src_data, nbytes);
     } else {
-      // GPU → CPU copy
+      // GPU → CPU synchronous copy
       executorch::backends::aoti::slim::DeviceTraits<
           executorch::backends::aoti::slim::c10::DeviceType::CUDA>::
           memcpy(
diff --git a/extension/asr/runner/runner.cpp b/extension/asr/runner/runner.cpp
index 21ff276bb82..58981f06862 100644
--- a/extension/asr/runner/runner.cpp
+++ b/extension/asr/runner/runner.cpp
@@ -46,6 +46,8 @@ AsrRunner::AsrRunner(
   }
 }

+AsrRunner::~AsrRunner() = default;
+
 bool AsrRunner::is_loaded() const {
   return module_ && encoder_method_loaded_ && decoder_method_loaded_ &&
       (!sampler_method_present_ || sampler_method_loaded_) && tokenizer_ &&
@@ -121,13 +123,20 @@ Error AsrRunner::load() {
 #ifdef CUDA_AVAILABLE
   // Skip copying outputs to CPU. When a sampler exists, keep both encoder and
   // decoder outputs on device and pass decoder logits directly into sampler.
-  executorch::runtime::BackendOptions<1> backend_options;
+  // The backend uses a single shared CUDA stream for all methods when the
+  // use_shared_cuda_stream option below is enabled, ensuring proper ordering.
+  executorch::runtime::BackendOptions<2> backend_options;
   std::string skip_methods = kEncoderMethodName;
   if (sampler_method_present_) {
     skip_methods.append(",").append(kDecoderMethodName);
   }
   ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
       "skip_copy_output_to_cpu_for_method", skip_methods.c_str()));
+  // Enable the shared CUDA stream for all methods when skip-copy is used.
+  // This ensures proper ordering between encoder/decoder/sampler outputs.
+  ET_CHECK_OK_OR_RETURN_ERROR(
+      backend_options.set_option("use_shared_cuda_stream", true));
+
   const auto opt_err =
       executorch::runtime::set_option("CudaBackend", backend_options.view());
   if (opt_err != ::executorch::runtime::Error::Ok) {
diff --git a/extension/asr/runner/runner.h b/extension/asr/runner/runner.h
index 077fdb69fe4..d8bb8f5c279 100644
--- a/extension/asr/runner/runner.h
+++ b/extension/asr/runner/runner.h
@@ -64,6 +64,8 @@ class ET_EXPERIMENTAL AsrRunner {
       std::optional<std::string> data_path,
       const std::string& tokenizer_path);

+  ~AsrRunner();
+
   /**
    * Returns true when the module and tokenizer are ready for inference.
    */
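Since `copy_slimtensor_to_etensor_async` only enqueues the device-to-host copy, the caller owns the synchronization point before reading the CPU buffer, as the docstring notes. A small self-contained sketch of that contract using plain CUDA calls (buffer names and sizes are illustrative):

```cpp
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  constexpr size_t kCount = 256;
  float* device_out = nullptr;
  cudaMalloc(&device_out, kCount * sizeof(float));
  cudaMemsetAsync(device_out, 0, kCount * sizeof(float), stream);

  std::vector<float> host_out(kCount);
  // Enqueue the device-to-host copy; the contract is that it need not have
  // completed when cudaMemcpyAsync returns.
  cudaMemcpyAsync(host_out.data(), device_out, kCount * sizeof(float),
                  cudaMemcpyDeviceToHost, stream);

  // Reading host_out before this point would race with the copy.
  cudaStreamSynchronize(stream);
  std::printf("first element: %f\n", host_out[0]);

  cudaFree(device_out);
  cudaStreamDestroy(stream);
  return 0;
}
```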