Skip to content

Commit 84cb621

Browse files
committed
[et device 1/n] introduce cuda device type
Pull Request resolved: #17400 This diff introduces a new device type, CUDA, which will be used by follow-up work on device type support. Differential Revision: [D92928772](https://our.internmc.facebook.com/intern/diff/D92928772/) ghstack-source-id: 340471758
1 parent aa2f683 commit 84cb621

3 files changed

Lines changed: 126 additions & 11 deletions

File tree

runtime/core/portable_type/device.h

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,33 @@
88

99
#pragma once
1010

11-
#include <executorch/runtime/platform/assert.h>
11+
#include <cstddef>
12+
#include <cstdint>
1213

1314
namespace executorch {
1415
namespace runtime {
1516
namespace etensor {
1617

17-
/// Denotes the specific genre of compute device.
18-
/// Subset of https://github.com/pytorch/pytorch/blob/main/c10/core/Device.h
18+
/// Represents the type of compute device.
19+
/// Note: ExecuTorch Device is distinct from PyTorch Device.
1920
enum class DeviceType : int8_t {
  /// Host (CPU) memory.
  CPU = 0,
  /// CUDA device memory.
  CUDA = 1,
};
2224

23-
/// An index representing a specific device; For cpu it should always be -1 or 0
25+
/// Count of DeviceType enumerators. Serves as the length of fixed-size
/// registry arrays indexed by device type, so it must be kept in step with
/// the DeviceType enum.
constexpr size_t kNumDeviceTypes{2};
27+
28+
/// Selects one particular device among devices of the same type
/// (for example, GPU 0 versus GPU 1). The value -1 stands for the
/// default/unspecified device of that type.
using DeviceIndex = int8_t;
2531

2632
/**
2733
* An abstraction for the compute device on which a tensor is located.
28-
* ExecuTorch doesn't allow dynamic dispatching based on device, so this type is
29-
* just a skeleton to allow certain kernels that expect device as an
30-
* argument to still be run.
3134
*
32-
* In ExecuTorch this is always expected to be CPU.
35+
* Tensors carry a Device to express where their underlying data resides
36+
* (e.g. CPU host memory vs CUDA device memory). The runtime uses this to
37+
* dispatch memory allocation to the appropriate device allocator.
3338
*/
3439
struct Device final {
3540
using Type = DeviceType;
@@ -39,7 +44,7 @@ struct Device final {
3944
/* implicit */ Device(DeviceType type, DeviceIndex index = -1)
4045
: type_(type), index_(index) {}
4146

42-
/// Returns the type of device this is. Only CPU is supported.
47+
/// Returns the type of device the tensor data resides on.
4348
DeviceType type() const noexcept {
4449
return type_;
4550
}
@@ -49,12 +54,19 @@ struct Device final {
4954
return type_ == DeviceType::CPU;
5055
}
5156

52-
/// Returns the device index. Always 0 if specified or -1 if not provided.
57+
/// Returns the device index, or -1 if default/unspecified.
5358
DeviceIndex index() const noexcept {
54-
ET_CHECK(index_ == 0 || index_ == -1);
5559
return index_;
5660
}
5761

62+
bool operator==(const Device& other) const noexcept {
63+
return type_ == other.type_ && index_ == other.index_;
64+
}
65+
66+
bool operator!=(const Device& other) const noexcept {
67+
return !(*this == other);
68+
}
69+
5870
private:
5971
DeviceType type_;
6072
DeviceIndex index_ = -1;
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/portable_type/device.h>
10+
11+
#include <gtest/gtest.h>
12+
13+
using executorch::runtime::etensor::Device;
14+
using executorch::runtime::etensor::DeviceIndex;
15+
using executorch::runtime::etensor::DeviceType;
16+
using executorch::runtime::etensor::kNumDeviceTypes;
17+
18+
// --- DeviceType enum ---
19+
20+
TEST(DeviceTypeTest, EnumValues) {
  // Pin the numeric value behind each enumerator.
  const auto cpu_raw = static_cast<int8_t>(DeviceType::CPU);
  const auto cuda_raw = static_cast<int8_t>(DeviceType::CUDA);
  EXPECT_EQ(cpu_raw, 0);
  EXPECT_EQ(cuda_raw, 1);
}
24+
25+
TEST(DeviceTypeTest, NumDeviceTypesCoversAllEnums) {
  // Each enumerator, used as an array index, has to fall strictly below
  // kNumDeviceTypes.
  const size_t cpu_slot = static_cast<size_t>(DeviceType::CPU);
  const size_t cuda_slot = static_cast<size_t>(DeviceType::CUDA);
  EXPECT_GT(kNumDeviceTypes, cpu_slot);
  EXPECT_GT(kNumDeviceTypes, cuda_slot);
}
30+
31+
// --- Device: CPU ---
32+
33+
TEST(DeviceTest, CpuDefaultIndex) {
  // No index is supplied, so the constructor's default of -1 applies.
  Device cpu_device(DeviceType::CPU);
  EXPECT_EQ(cpu_device.type(), DeviceType::CPU);
  EXPECT_TRUE(cpu_device.is_cpu());
  EXPECT_EQ(cpu_device.index(), -1);
}
39+
40+
TEST(DeviceTest, CpuExplicitIndex) {
  // An explicitly supplied index must be returned unchanged.
  Device cpu_device(DeviceType::CPU, 0);
  EXPECT_EQ(cpu_device.index(), 0);
  EXPECT_TRUE(cpu_device.is_cpu());
}
45+
46+
// --- Device: CUDA ---
47+
48+
TEST(DeviceTest, CudaDefaultIndex) {
  // A CUDA device built without an index reports the default index of -1.
  Device cuda_device(DeviceType::CUDA);
  EXPECT_EQ(cuda_device.type(), DeviceType::CUDA);
  EXPECT_FALSE(cuda_device.is_cpu());
  EXPECT_EQ(cuda_device.index(), -1);
}
54+
55+
TEST(DeviceTest, CudaExplicitIndex) {
  // Passing an explicit index must not disturb the stored device type, so
  // check the type alongside the index (mirrors the CPU explicit-index test).
  Device d(DeviceType::CUDA, 0);
  EXPECT_FALSE(d.is_cpu());
  EXPECT_EQ(d.type(), DeviceType::CUDA);
  EXPECT_EQ(d.index(), 0);
}
59+
60+
// --- Device: equality ---
61+
62+
TEST(DeviceTest, EqualitySameTypeAndIndex) {
  // Two devices agreeing on both type and index compare equal.
  const Device cpu_lhs(DeviceType::CPU, 0);
  const Device cpu_rhs(DeviceType::CPU, 0);
  EXPECT_EQ(cpu_lhs, cpu_rhs);

  const Device cuda_lhs(DeviceType::CUDA, 1);
  const Device cuda_rhs(DeviceType::CUDA, 1);
  EXPECT_EQ(cuda_lhs, cuda_rhs);
}
66+
67+
TEST(DeviceTest, InequalityDifferentType) {
  // Same index, different type: the devices must compare unequal.
  const Device on_cpu(DeviceType::CPU, 0);
  const Device on_cuda(DeviceType::CUDA, 0);
  EXPECT_NE(on_cpu, on_cuda);
}
70+
71+
TEST(DeviceTest, InequalityDifferentIndex) {
  // Same type, different index: the devices must compare unequal.
  const Device first_gpu(DeviceType::CUDA, 0);
  const Device second_gpu(DeviceType::CUDA, 1);
  EXPECT_NE(first_gpu, second_gpu);
}
74+
75+
TEST(DeviceTest, EqualityDefaultIndices) {
  // Devices built without an index all use the default index, so comparison
  // is decided by the type alone.
  const Device default_cpu(DeviceType::CPU);
  const Device default_cuda(DeviceType::CUDA);
  EXPECT_EQ(default_cpu, Device(DeviceType::CPU));
  EXPECT_EQ(default_cuda, Device(DeviceType::CUDA));
  EXPECT_NE(default_cpu, default_cuda);
}
80+
81+
// --- Device: implicit construction ---
82+
83+
TEST(DeviceTest, ImplicitConstructionFromDeviceType) {
  // Device's constructor is implicit, allowing DeviceType -> Device conversion.
  Device d = DeviceType::CUDA;
  // The converted Device must carry the source type; checking only the index
  // would also pass for a Device of the wrong type.
  EXPECT_EQ(d.type(), DeviceType::CUDA);
  EXPECT_FALSE(d.is_cpu());
  // The index falls back to the constructor's default.
  EXPECT_EQ(d.index(), -1);
}
88+
89+
// --- Deprecated namespace aliases ---
90+
91+
TEST(DeviceTest, DeprecatedNamespaceAliases) {
  // Verify the torch::executor aliases still work: a Device built through
  // them must report both the type and the index it was given.
  torch::executor::Device d(torch::executor::DeviceType::CUDA, 0);
  EXPECT_EQ(d.type(), torch::executor::DeviceType::CUDA);
  EXPECT_EQ(d.index(), 0);
}

runtime/core/portable_type/test/targets.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ def define_common_targets():
4747
],
4848
)
4949

50+
runtime.cxx_test(
51+
name = "device_test",
52+
srcs = ["device_test.cpp"],
53+
deps = [
54+
"//executorch/runtime/core/portable_type:portable_type",
55+
],
56+
)
57+
5058
runtime.cxx_test(
5159
name = "tensor_impl_test",
5260
srcs = ["tensor_impl_test.cpp"],

0 commit comments

Comments
 (0)