Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions backends/cuda/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,28 @@ runtime.cxx_library(
],
)

runtime.cxx_library(
name = "op__device_copy",
srcs = [
"op__device_copy.cpp",
],
# Constructor needed for op registration.
compiler_flags = ["-Wno-global-constructors"],
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
link_whole = True,
supports_python_dlopen = True,
visibility = ["PUBLIC"],
deps = [
":cuda_allocator",
"//executorch/extension/kernel_util:kernel_util",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/kernel:kernel_includes",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
)

runtime.cxx_library(
name = "cuda_backend",
srcs = [
Expand All @@ -120,6 +142,7 @@ runtime.cxx_library(
":cuda_platform",
":runtime_shims",
":cuda_allocator",
":op__device_copy",
":cuda_platform",
"//executorch/backends/aoti:aoti_common_slim",
"//executorch/backends/aoti/slim/core:slimtensor",
Expand Down
122 changes: 122 additions & 0 deletions backends/cuda/runtime/op__device_copy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cuda/runtime/cuda_allocator.h>
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace executorch::backends::cuda {

using executorch::aten::Tensor;
using executorch::runtime::Error;
using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::etensor::DeviceType;

Tensor& _h2d_copy_out(
KernelRuntimeContext& ctx,
const Tensor& self,
Tensor& out) {
const auto* self_impl = self.unsafeGetTensorImpl();
const auto* out_impl = out.unsafeGetTensorImpl();
const auto device_index = out_impl->device_index();

ET_KERNEL_CHECK_MSG(
ctx,
self_impl->device_type() == DeviceType::CPU,
InvalidArgument,
out,
"_h2d_copy: source tensor must be on CPU, got device_type=%d",
static_cast<int>(self_impl->device_type()));

ET_KERNEL_CHECK_MSG(
ctx,
out_impl->device_type() == DeviceType::CUDA,
InvalidArgument,
out,
"_h2d_copy: destination tensor must be on CUDA, got device_type=%d",
static_cast<int>(out_impl->device_type()));

const size_t nbytes = self.nbytes();
ET_KERNEL_CHECK_MSG(
ctx,
nbytes == out.nbytes(),
InvalidArgument,
out,
"_h2d_copy: size mismatch: self.nbytes()=%zu, out.nbytes()=%zu",
nbytes,
out.nbytes());

const Error err = CudaAllocator::instance().copy_host_to_device(
out.mutable_data_ptr(), self.const_data_ptr(), nbytes, device_index);
ET_KERNEL_CHECK_MSG(
ctx,
err == Error::Ok,
Internal,
out,
"_h2d_copy: copy_host_to_device failed");

return out;
}

Tensor& _d2h_copy_out(
KernelRuntimeContext& ctx,
const Tensor& self,
Tensor& out) {
const auto* self_impl = self.unsafeGetTensorImpl();
const auto* out_impl = out.unsafeGetTensorImpl();
const auto device_index = self_impl->device_index();

ET_KERNEL_CHECK_MSG(
ctx,
self_impl->device_type() == DeviceType::CUDA,
InvalidArgument,
out,
"_d2h_copy: source tensor must be on CUDA, got device_type=%d",
static_cast<int>(self_impl->device_type()));

ET_KERNEL_CHECK_MSG(
ctx,
out_impl->device_type() == DeviceType::CPU,
InvalidArgument,
out,
"_d2h_copy: destination tensor must be on CPU, got device_type=%d",
static_cast<int>(out_impl->device_type()));

const size_t nbytes = self.nbytes();
ET_KERNEL_CHECK_MSG(
ctx,
nbytes == out.nbytes(),
InvalidArgument,
out,
"_d2h_copy: size mismatch: self.nbytes()=%zu, out.nbytes()=%zu",
nbytes,
out.nbytes());

const Error err = CudaAllocator::instance().copy_device_to_host(
out.mutable_data_ptr(), self.const_data_ptr(), nbytes, device_index);
ET_KERNEL_CHECK_MSG(
ctx,
err == Error::Ok,
Internal,
out,
"_d2h_copy: copy_device_to_host failed");

return out;
}

} // namespace executorch::backends::cuda

EXECUTORCH_LIBRARY(
et_copy,
"_h2d_copy.out",
executorch::backends::cuda::_h2d_copy_out);
EXECUTORCH_LIBRARY(
et_copy,
"_d2h_copy.out",
executorch::backends::cuda::_d2h_copy_out);
21 changes: 21 additions & 0 deletions backends/cuda/runtime/shims/tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,24 @@ def define_common_targets():
cuda_shim_cpp_unittest("aoti_torch_new_tensor_handle")
cuda_shim_cpp_unittest("aoti_torch_item_bool")
cuda_shim_cpp_unittest("aoti_torch_assign_tensors_out")

cpp_unittest(
name = "test_op__device_copy",
srcs = ["test_op__device_copy.cpp"],
deps = [
"//executorch/backends/cuda/runtime:cuda_allocator",
"//executorch/backends/cuda/runtime:op__device_copy",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/portable_type:portable_type",
"//executorch/runtime/kernel:kernel_runtime_context",
"//executorch/runtime/platform:platform",
],
external_deps = [
("cuda", None, "cuda-lazy"),
],
preprocessor_flags = ["-DCUDA_AVAILABLE=1"],
keep_gpu_sections = True,
remote_execution = re_test_utils.remote_execution(
platform = "gpu-remote-execution",
),
)
192 changes: 192 additions & 0 deletions backends/cuda/runtime/shims/tests/test_op__device_copy.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <cuda_runtime.h>
#include <executorch/backends/cuda/runtime/cuda_allocator.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/portable_type/tensor_impl.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>
#include <executorch/runtime/platform/runtime.h>
#include <gtest/gtest.h>

#if (defined(__has_feature) && __has_feature(address_sanitizer)) || \
defined(__SANITIZE_ADDRESS__)
#include <sanitizer/lsan_interface.h>
#define EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE 1
#else
#define EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE 0
#endif

#include <cstdint>
#include <memory>
#include <vector>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;
using executorch::aten::TensorImpl;
using executorch::backends::cuda::CudaAllocator;
using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::TensorShapeDynamism;
using executorch::runtime::etensor::DeviceIndex;
using executorch::runtime::etensor::DeviceType;

namespace executorch::backends::cuda {
Tensor& _h2d_copy_out(
KernelRuntimeContext& ctx,
const Tensor& self,
Tensor& out);
Tensor& _d2h_copy_out(
KernelRuntimeContext& ctx,
const Tensor& self,
Tensor& out);
} // namespace executorch::backends::cuda

namespace {

struct CudaDeleter {
void operator()(void* ptr) const {
CudaAllocator::instance().deallocate(ptr, device_index);
}

DeviceIndex device_index = 0;
};

using CudaPtr = std::unique_ptr<void, CudaDeleter>;

CudaPtr allocate_cuda(size_t nbytes, DeviceIndex device_index = 0) {
auto result = CudaAllocator::instance().allocate(nbytes, device_index);
EXPECT_TRUE(result.ok()) << "CudaAllocator::allocate failed";
return CudaPtr(
result.ok() ? result.get() : nullptr, CudaDeleter{device_index});
}

bool is_cuda_available() {
#if EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE
__lsan_disable();
#endif
int device_count = 0;
const cudaError_t err = cudaGetDeviceCount(&device_count);
#if EXECUTORCH_CUDA_DEVICE_COPY_HAS_LSAN_INTERFACE
__lsan_enable();
#endif
return err == cudaSuccess && device_count > 0;
}

std::vector<float> copy_cuda_to_host(const void* device_ptr, size_t numel) {
std::vector<float> host(numel);
const cudaError_t err = cudaMemcpy(
host.data(), device_ptr, numel * sizeof(float), cudaMemcpyDeviceToHost);
EXPECT_EQ(err, cudaSuccess) << "cudaMemcpy D2H failed";
return host;
}

void copy_host_to_cuda(const std::vector<float>& host, void* device_ptr) {
const cudaError_t err = cudaMemcpy(
device_ptr,
host.data(),
host.size() * sizeof(float),
cudaMemcpyHostToDevice);
EXPECT_EQ(err, cudaSuccess) << "cudaMemcpy H2D failed";
}

class CudaDeviceCopyOpTest : public ::testing::Test {
protected:
static void SetUpTestSuite() {
executorch::runtime::runtime_init();
}

void SetUp() override {
if (!is_cuda_available()) {
GTEST_SKIP() << "CUDA not available, skipping CUDA device copy op tests";
}
}
};

} // namespace

TEST_F(CudaDeviceCopyOpTest, H2dCopyUsesCudaAllocatorCopy) {
std::vector<float> src_data = {1.0f, 2.0f, 3.0f, 4.0f};
auto device_data = allocate_cuda(src_data.size() * sizeof(float));
ASSERT_NE(device_data.get(), nullptr);

int32_t sizes[] = {static_cast<int32_t>(src_data.size())};
uint8_t dim_order[] = {0};
int32_t strides[] = {1};

TensorImpl src_impl(
ScalarType::Float,
1,
sizes,
src_data.data(),
dim_order,
strides,
TensorShapeDynamism::STATIC,
DeviceType::CPU,
0);
Tensor src(&src_impl);

TensorImpl dst_impl(
ScalarType::Float,
1,
sizes,
device_data.get(),
dim_order,
strides,
TensorShapeDynamism::STATIC,
DeviceType::CUDA,
0);
Tensor dst(&dst_impl);

KernelRuntimeContext ctx;
Tensor& result = executorch::backends::cuda::_h2d_copy_out(ctx, src, dst);

EXPECT_EQ(&result, &dst);
EXPECT_EQ(copy_cuda_to_host(device_data.get(), src_data.size()), src_data);
}

TEST_F(CudaDeviceCopyOpTest, D2hCopyUsesCudaAllocatorCopy) {
const std::vector<float> expected = {5.0f, 6.0f, 7.0f, 8.0f};
auto device_data = allocate_cuda(expected.size() * sizeof(float));
ASSERT_NE(device_data.get(), nullptr);
copy_host_to_cuda(expected, device_data.get());

std::vector<float> dst_data(expected.size(), 0.0f);
int32_t sizes[] = {static_cast<int32_t>(expected.size())};
uint8_t dim_order[] = {0};
int32_t strides[] = {1};

TensorImpl src_impl(
ScalarType::Float,
1,
sizes,
device_data.get(),
dim_order,
strides,
TensorShapeDynamism::STATIC,
DeviceType::CUDA,
0);
Tensor src(&src_impl);

TensorImpl dst_impl(
ScalarType::Float,
1,
sizes,
dst_data.data(),
dim_order,
strides,
TensorShapeDynamism::STATIC,
DeviceType::CPU,
0);
Tensor dst(&dst_impl);

KernelRuntimeContext ctx;
Tensor& result = executorch::backends::cuda::_d2h_copy_out(ctx, src, dst);

EXPECT_EQ(&result, &dst);
EXPECT_EQ(dst_data, expected);
}
Loading
Loading