Skip to content

Commit cf5b07f

Browse files
committed
fix: replace unsafe d2h/h2d Memcpy calls with synchronous MemcpyAsync + SynchronizeStream
1 parent 616ad5b commit cf5b07f

11 files changed

Lines changed: 133 additions & 37 deletions

File tree

infini_train/src/core/cuda/cuda_guard_impl.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ void CudaGuardImpl::SynchronizeDevice(Device device) const {
100100
SetDevice(original_device);
101101
}
102102

103+
void CudaGuardImpl::SynchronizeStream(Stream *stream) const {
104+
CUDA_CHECK(cudaStreamSynchronize(dynamic_cast<CudaStream *>(stream)->cuda_stream()));
105+
}
106+
103107
// blas
104108
BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const {
105109
CheckCudaDevice(device);

infini_train/src/core/cuda/cuda_guard_impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class CudaGuardImpl final : public DeviceGuardImpl {
3535
// sync
3636
void SynchronizeDevice(Device device) const override;
3737

38+
void SynchronizeStream(Stream *) const override;
39+
3840
// blas
3941
BlasHandle *GetBlasHandle(Device device) const override;
4042

infini_train/src/kernels/cuda/concat.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ std::shared_ptr<Tensor> ConcatForward(const std::vector<std::shared_ptr<Tensor>>
125125

126126
CUDA_CHECK(cudaFreeAsync(device_input_ptrs, stream));
127127
CUDA_CHECK(cudaFreeAsync(device_offsets, stream));
128+
// NOTE(dcj):
129+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
130+
// operations have completed before the host buffers go out of scope.
131+
CUDA_CHECK(cudaStreamSynchronize(stream));
128132
},
129133
"CUDA ConcatForward");
130134

@@ -230,6 +234,10 @@ std::vector<std::shared_ptr<Tensor>> ConcatBackward(const std::shared_ptr<Tensor
230234

231235
CUDA_CHECK(cudaFreeAsync(device_ptrs, stream));
232236
CUDA_CHECK(cudaFreeAsync(device_offsets, stream));
237+
// NOTE(dcj):
238+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
239+
// operations have completed before the host buffers go out of scope.
240+
CUDA_CHECK(cudaStreamSynchronize(stream));
233241
},
234242
"CUDA ConcatBackward");
235243

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ void LaunchForward(Func func, const std::shared_ptr<Tensor> &output, const Input
123123
host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end());
124124
host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());
125125

126-
CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice));
126+
CUDA_CHECK(cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t),
127+
cudaMemcpyHostToDevice, cuda_stream));
127128

128129
LaunchKernel<BLOCK_SIZE, T>(
129130
[&](dim3 grid, dim3 block, size_t offset, const T *a_ptr, const T *b_ptr) {
@@ -134,6 +135,11 @@ void LaunchForward(Func func, const std::shared_ptr<Tensor> &output, const Input
134135
output, inputs...);
135136

136137
CUDA_CHECK(cudaFreeAsync(device_buffer, cuda_stream));
138+
139+
// NOTE(dcj):
140+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
141+
// operations have completed before the host buffers go out of scope.
142+
CUDA_CHECK(cudaStreamSynchronize(cuda_stream));
137143
} else {
138144
static_assert(sizeof...(inputs) == 1 || sizeof...(inputs) == 2,
139145
"LaunchForward currently only supports unary and binary operations.");
@@ -553,7 +559,8 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
553559
host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end());
554560
host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());
555561

556-
CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice));
562+
CUDA_CHECK(
563+
cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
557564

558565
const size_t num_elements = grad_output->NumElements();
559566

@@ -616,6 +623,10 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
616623
output_a, inputs...);
617624
}
618625
CUDA_CHECK(cudaFreeAsync(device_buffer, stream));
626+
// NOTE(dcj):
627+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
628+
// operations have completed before the host buffers go out of scope.
629+
CUDA_CHECK(cudaStreamSynchronize(stream));
619630
}
620631

621632
template <typename Func> std::shared_ptr<Tensor> UnaryForward(const std::shared_ptr<Tensor> &input, Func unary_fn) {

infini_train/src/kernels/cuda/gather.cu

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ std::shared_ptr<Tensor> IndexGatherForward(const std::shared_ptr<Tensor> &input,
107107
"CUDA IndexGatherForward");
108108

109109
CUDA_CHECK(cudaFreeAsync(dev_buf, stream));
110+
111+
// NOTE(dcj):
112+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
113+
// operations have completed before the host buffers go out of scope.
114+
CUDA_CHECK(cudaStreamSynchronize(stream));
115+
110116
return out;
111117
}
112118

@@ -214,6 +220,10 @@ std::shared_ptr<Tensor> IndexGatherBackward(const std::shared_ptr<Tensor> &grad_
214220
"CUDA IndexGatherBackward");
215221

216222
CUDA_CHECK(cudaFreeAsync(dev_buf, stream));
223+
// NOTE(dcj):
224+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
225+
// operations have completed before the host buffers go out of scope.
226+
CUDA_CHECK(cudaStreamSynchronize(stream));
217227
return grad_input;
218228
}
219229

infini_train/src/kernels/cuda/slice.cu

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,16 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
8181
input_strides_dev = steps_dev + steps.size();
8282
output_strides_dev = input_strides_dev + dims.size();
8383

84-
CUDA_CHECK(cudaMemcpy(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
85-
CUDA_CHECK(cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
86-
CUDA_CHECK(cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
8784
CUDA_CHECK(
88-
cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
85+
cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
8986
CUDA_CHECK(
90-
cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
87+
cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
88+
CUDA_CHECK(
89+
cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
90+
CUDA_CHECK(cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t),
91+
cudaMemcpyHostToDevice, stream));
92+
CUDA_CHECK(cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t),
93+
cudaMemcpyHostToDevice, stream));
9194

9295
int threads_per_block = 256;
9396
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
@@ -103,6 +106,11 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
103106

104107
cudaFreeAsync(new_dims_dev, stream);
105108

109+
// NOTE(dcj):
110+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
111+
// operations have completed before the host buffers go out of scope.
112+
CUDA_CHECK(cudaStreamSynchronize(stream));
113+
106114
return new_tensor;
107115
}
108116

@@ -175,13 +183,16 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
175183
input_strides_dev = steps_dev + steps.size();
176184
output_strides_dev = input_strides_dev + dims.size();
177185

178-
CUDA_CHECK(cudaMemcpy(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
179-
CUDA_CHECK(cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
180-
CUDA_CHECK(cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
181186
CUDA_CHECK(
182-
cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
187+
cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
183188
CUDA_CHECK(
184-
cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice));
189+
cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
190+
CUDA_CHECK(
191+
cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
192+
CUDA_CHECK(cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t),
193+
cudaMemcpyHostToDevice, stream));
194+
CUDA_CHECK(cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t),
195+
cudaMemcpyHostToDevice, stream));
185196

186197
int threads_per_block = 256;
187198
int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
@@ -197,6 +208,11 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
197208

198209
CUDA_CHECK(cudaFreeAsync(new_dims_dev, stream));
199210

211+
// NOTE(dcj):
212+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
213+
// operations have completed before the host buffers go out of scope.
214+
CUDA_CHECK(cudaStreamSynchronize(stream));
215+
200216
return grad_input;
201217
}
202218
} // namespace infini_train::kernels::cuda

infini_train/src/kernels/cuda/split.cu

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,15 @@ std::shared_ptr<Tensor> LaunchSplitBackward(const std::vector<int64_t> &input_di
137137
device_grad_output_ptrs = (const T **)(device_ptr);
138138
device_H_outs = reinterpret_cast<int64_t *>(device_grad_output_ptrs + num_splits);
139139

140-
CUDA_CHECK(cudaMemcpy(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits,
141-
cudaMemcpyHostToDevice));
140+
CUDA_CHECK(cudaMemcpyAsync(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits,
141+
cudaMemcpyHostToDevice, stream));
142142

143143
// init H_out for each split
144144
std::vector<int64_t> H_outs(num_splits);
145145
for (int i = 0; i < num_splits; ++i) { H_outs[i] = std::min(split_size, H_in - i * split_size); }
146146

147-
CUDA_CHECK(cudaMemcpy(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice));
147+
CUDA_CHECK(
148+
cudaMemcpyAsync(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice, stream));
148149

149150
int64_t total_elements = N * H_in * W;
150151
int threads_per_block = 256;
@@ -156,6 +157,11 @@ std::shared_ptr<Tensor> LaunchSplitBackward(const std::vector<int64_t> &input_di
156157

157158
CUDA_CHECK(cudaFreeAsync(device_ptr, stream));
158159

160+
// NOTE(dcj):
161+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
162+
// operations have completed before the host buffers go out of scope.
163+
CUDA_CHECK(cudaStreamSynchronize(stream));
164+
159165
return grad_input;
160166
}
161167

infini_train/src/kernels/cuda/stack.cu

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,18 @@ std::shared_ptr<Tensor> StackForward(const std::vector<std::shared_ptr<Tensor>>
6868

6969
const T **device_input_ptrs;
7070
CUDA_CHECK(cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream));
71-
CUDA_CHECK(cudaMemcpy(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
72-
cudaMemcpyHostToDevice));
71+
CUDA_CHECK(cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
72+
cudaMemcpyHostToDevice, stream));
7373

7474
StackForwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(
7575
device_input_ptrs, static_cast<T *>(output->DataPtr()), N, D, num_inputs);
7676

7777
CUDA_CHECK(cudaFreeAsync(device_input_ptrs, stream));
78+
79+
// NOTE(dcj):
80+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
81+
// operations have completed before the host buffers go out of scope.
82+
CUDA_CHECK(cudaStreamSynchronize(stream));
7883
},
7984
"CUDA StackForward");
8085

@@ -137,12 +142,18 @@ std::vector<std::shared_ptr<Tensor>> StackBackward(const std::vector<int64_t> &i
137142

138143
T **device_ptrs;
139144
CUDA_CHECK(cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream));
140-
CUDA_CHECK(cudaMemcpy(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice));
145+
CUDA_CHECK(cudaMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice,
146+
stream));
141147

142148
StackBackwardKernel<<<num_blocks, threads_per_block, 0, stream>>>(
143149
static_cast<const T *>(grad_output->DataPtr()), device_ptrs, N, D, num_inputs);
144150

145151
CUDA_CHECK(cudaFreeAsync(device_ptrs, stream));
152+
153+
// NOTE(dcj):
154+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
155+
// operations have completed before the host buffers go out of scope.
156+
CUDA_CHECK(cudaStreamSynchronize(stream));
146157
},
147158
"CUDA StackBackward");
148159

infini_train/src/kernels/cuda/transform.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,8 @@ std::shared_ptr<Tensor> TransposeForward(const std::shared_ptr<Tensor> &input, i
263263
host_buffer.insert(host_buffer.end(), in_strides.begin(), in_strides.end());
264264
host_buffer.insert(host_buffer.end(), out_strides.begin(), out_strides.end());
265265

266-
CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice));
266+
CUDA_CHECK(
267+
cudaMemcpyAsync(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
267268

268269
int threads_per_block = 256;
269270
int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block;
@@ -280,6 +281,11 @@ std::shared_ptr<Tensor> TransposeForward(const std::shared_ptr<Tensor> &input, i
280281

281282
CUDA_CHECK(cudaFreeAsync(device_buffer, stream));
282283

284+
// NOTE(dcj):
285+
// Synchronize the stream here to ensure all preceding H2D/D2H memcpy
286+
// operations have completed before the host buffers go out of scope.
287+
CUDA_CHECK(cudaStreamSynchronize(stream));
288+
283289
return output;
284290
}
285291

infini_train/src/nn/init.cc

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ std::shared_ptr<Tensor> Normal(const std::shared_ptr<Tensor> &tensor, float mean
5050
core::DeviceGuard guard(device);
5151
auto impl = core::GetDeviceGuardImpl(device.type());
5252

53-
impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
54-
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);
53+
impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
54+
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
55+
impl->GetStream(device));
56+
impl->SynchronizeStream(impl->GetStream(device));
5557
return tensor;
5658
}
5759

@@ -142,8 +144,10 @@ std::shared_ptr<Tensor> Uniform(const std::shared_ptr<Tensor> &tensor, float a,
142144
core::DeviceGuard guard(device);
143145
auto impl = core::GetDeviceGuardImpl(device.type());
144146

145-
impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
146-
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);
147+
impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
148+
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
149+
impl->GetStream(device));
150+
impl->SynchronizeStream(impl->GetStream(device));
147151

148152
return tensor;
149153
}
@@ -159,8 +163,10 @@ std::shared_ptr<Tensor> Ones(const std::shared_ptr<Tensor> &tensor) {
159163

160164
auto impl = core::GetDeviceGuardImpl(device.type());
161165

162-
impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
163-
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);
166+
impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
167+
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
168+
impl->GetStream(device));
169+
impl->SynchronizeStream(impl->GetStream(device));
164170

165171
return tensor;
166172
}
@@ -176,8 +182,10 @@ std::shared_ptr<Tensor> Zeros(const std::shared_ptr<Tensor> &tensor) {
176182

177183
auto impl = core::GetDeviceGuardImpl(device.type());
178184

179-
impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
180-
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D);
185+
impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
186+
device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
187+
impl->GetStream(device));
188+
impl->SynchronizeStream(impl->GetStream(device));
181189

182190
return tensor;
183191
}
@@ -186,7 +194,8 @@ std::shared_ptr<Tensor> Zeros(const std::shared_ptr<Tensor> &tensor) {
186194
case DATA_TYPE: { \
187195
std::vector<TYPE> buffer(num_elements); \
188196
std::iota(buffer.begin(), buffer.end(), static_cast<TYPE>(start)); \
189-
impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind); \
197+
impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind, stream); \
198+
impl->SynchronizeStream(stream); \
190199
break; \
191200
}
192201

@@ -198,6 +207,7 @@ std::shared_ptr<Tensor> Arange(int64_t start, int64_t end, DataType dtype, Devic
198207
auto *impl = core::GetDeviceGuardImpl(device.type());
199208

200209
const core::MemcpyKind kind = device.IsCPU() ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D;
210+
core::Stream *stream = impl->GetStream(device);
201211

202212
switch (dtype) {
203213
ARANGE_CASE(DataType::kUINT8, uint8_t)

0 commit comments

Comments (0)