Skip to content

Commit 7a43321

Browse files
committed
refactor: clean up device runtime interfaces and initialization semantics
- Split internal implementation headers into a separate include group
- Drop redundant explicit default initialization for Device
- Add `impl` suffix to CUDA guard implementation files
- Unify Arange initialization via DeviceGuardImpl
1 parent 9a39990 commit 7a43321

35 files changed

Lines changed: 87 additions & 79 deletions

infini_train/include/autograd/comm.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Scatter : public autograd::Function {
3333
private:
3434
const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
3535
std::vector<Device> target_gpus_;
36-
Device input_device_ = Device();
36+
Device input_device_;
3737
int64_t dim_ = 0;
3838
};
3939

@@ -52,7 +52,7 @@ class Gather : public autograd::Function {
5252

5353
private:
5454
const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
55-
Device target_device_ = Device();
55+
Device target_device_;
5656
std::vector<Device> input_gpus_;
5757
int64_t dim_ = 0;
5858
bool unsqueezed_scalar_ = false;
@@ -76,7 +76,7 @@ class Broadcast : public autograd::Function {
7676
const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
7777
std::vector<Device> target_gpus_;
7878
int64_t num_inputs_ = 0;
79-
Device input_device_ = Device();
79+
Device input_device_;
8080
};
8181

8282
class ReduceAddCoalesced : public autograd::Function {
@@ -95,7 +95,7 @@ class ReduceAddCoalesced : public autograd::Function {
9595

9696
private:
9797
const infini_train::nn::parallel::ProcessGroup *pg_ = nullptr;
98-
Device destination_ = Device();
98+
Device destination_;
9999
std::vector<Device> target_gpus_;
100100
int64_t num_inputs_ = 0;
101101
};

infini_train/include/core/device_guard.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace infini_train::core {
1111
class Stream;
1212
class BlasHandle;
1313

14+
// Note(dcj): In the CPU backend, kD2D corresponds to a regular memcpy.
1415
enum class MemcpyKind : int8_t {
1516
kH2D = 0,
1617
kD2H = 1,
@@ -161,7 +162,7 @@ class DeviceGuardImplRegistry {
161162
DeviceGuardImpl *Get(Device::DeviceType type) const;
162163

163164
private:
164-
DeviceGuardImplRegistry();
165+
DeviceGuardImplRegistry() = default;
165166
DeviceGuardImplRegistry(const DeviceGuardImplRegistry &) = delete;
166167
DeviceGuardImplRegistry &operator=(const DeviceGuardImplRegistry &) = delete;
167168

infini_train/include/nn/parallel/data_parallel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class DataParallel : public Module {
2020
private:
2121
int dim_ = 0;
2222
std::vector<Device> devices_;
23-
Device output_device_ = Device();
24-
Device src_device_ = Device();
23+
Device output_device_;
24+
Device src_device_;
2525
};
2626
} // namespace infini_train::nn::parallel

infini_train/include/nn/parallel/pp/pipeline_stage.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class PipelineStage {
4242
int num_stages_ = -1;
4343
int prev_rank_ = -1;
4444
int next_rank_ = -1;
45-
Device device_ = Device();
45+
Device device_;
4646
std::vector<std::shared_ptr<Module>> chunks_;
4747
std::vector<std::vector<int64_t>> recv_shape_;
4848
};

infini_train/include/nn/parallel/work.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class WorkNccl final : public Work {
5858
void SetException(std::exception_ptr e);
5959

6060
private:
61-
Device device_ = Device();
61+
Device device_;
6262
cudaEvent_t ready_event_;
6363
cudaEvent_t done_event_;
6464
ncclComm_t comm_;

infini_train/include/tensor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class TensorBuffer {
4848
size_t Size() const;
4949

5050
private:
51-
Device device_ = Device();
51+
Device device_;
5252
size_t size_ = 0;
5353
void *data_ = nullptr;
5454
};

infini_train/src/core/cpu/cpu_guard.cc renamed to infini_train/src/core/cpu/cpu_guard_impl.cc

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1-
#include "infini_train/src/core/cpu/cpu_guard.h"
1+
#include "infini_train/src/core/cpu/cpu_guard_impl.h"
22

33
#include <cstdlib>
44
#include <cstring>
55

6+
#include "glog/logging.h"
7+
8+
#include "infini_train/include/core/device_guard.h"
9+
610
namespace infini_train::core::cpu {
711

812
CpuGuardImpl::CpuGuardImpl() {}
@@ -15,6 +19,13 @@ void CpuGuardImpl::Malloc(void **dev_ptr, size_t size) { *dev_ptr = std::malloc(
1519

1620
void CpuGuardImpl::Free(void *dev_ptr) { std::free(dev_ptr); }
1721

18-
void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) { std::memcpy(dst, src, count); }
22+
void CpuGuardImpl::Memcpy(void *dst, const void *src, size_t count, MemcpyKind kind) {
23+
CHECK(kind == MemcpyKind::kD2D) << "CpuGuardImpl::Memcpy only supports kD2D (host-to-host) memcpy, "
24+
<< "but got MemcpyKind=" << static_cast<int>(kind);
25+
26+
std::memcpy(dst, src, count);
27+
}
28+
29+
INFINI_TRAIN_REGISTER_DEVICE_GUARD_IMPL(Device::DeviceType::kCPU, CpuGuardImpl)
1930

2031
} // namespace infini_train::core::cpu
File renamed without changes.

infini_train/src/core/cuda/cuda_guard.cc renamed to infini_train/src/core/cuda/cuda_guard_impl.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "infini_train/src/core/cuda/cuda_guard.h"
1+
#include "infini_train/src/core/cuda/cuda_guard_impl.h"
22

33
#include <array>
44
#include <cstdint>

infini_train/src/core/cuda/cuda_guard.h renamed to infini_train/src/core/cuda/cuda_guard_impl.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
#include <cstdint>
44

5-
#include "infini_train/include/core/blas_handle.h"
65
#include "infini_train/include/core/device_guard.h"
7-
#include "infini_train/include/core/stream.h"
86
#include "infini_train/include/device.h"
97

8+
namespace infini_train::core {
9+
class Stream;
10+
class BlasHandle;
11+
} // namespace infini_train::core
12+
1013
namespace infini_train::core::cuda {
1114

1215
class CudaGuardImpl : public DeviceGuardImpl {

0 commit comments

Comments
 (0)