From 616ad5bec7cc4f790b8d13a3ca826561780582a8 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 13 Feb 2026 01:56:47 +0000 Subject: [PATCH 1/2] fix: replace unsafe d2h/h2d MemcpyAsync calls with synchronous Memcpy --- infini_train/src/kernels/cuda/concat.cu | 15 ++++--- infini_train/src/kernels/cuda/elementwise.cu | 13 +++--- infini_train/src/kernels/cuda/gather.cu | 18 ++++----- infini_train/src/kernels/cuda/slice.cu | 42 ++++++++++---------- infini_train/src/kernels/cuda/split.cu | 10 ++--- infini_train/src/kernels/cuda/stack.cu | 14 +++---- infini_train/src/kernels/cuda/transform.cu | 6 +-- infini_train/src/nn/init.cc | 23 +++++------ infini_train/src/tensor.cc | 22 +++++----- 9 files changed, 74 insertions(+), 89 deletions(-) diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index 1fa8face..f475fd1e 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -113,12 +113,12 @@ std::shared_ptr ConcatForward(const std::vector> int64_t *device_offsets = nullptr; CUDA_CHECK(cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream)); - CUDA_CHECK(cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, - cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpy(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, + cudaMemcpyHostToDevice)); CUDA_CHECK(cudaMallocAsync(&device_offsets, sizeof(int64_t) * (num_inputs + 1), stream)); - CUDA_CHECK(cudaMemcpyAsync(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1), - cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpy(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1), + cudaMemcpyHostToDevice)); ConcatForwardKernel<<>>( device_input_ptrs, static_cast(output->DataPtr()), device_offsets, N, D, num_inputs, K_total); @@ -219,12 +219,11 @@ std::vector> ConcatBackward(const std::shared_ptr<<>>( static_cast(grad_output->DataPtr()), device_ptrs, device_offsets, N, D, num_inputs, K_total); diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu index 1b9fe9eb..be97af97 100644 --- a/infini_train/src/kernels/cuda/elementwise.cu +++ b/infini_train/src/kernels/cuda/elementwise.cu @@ -107,7 +107,7 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input auto out_stride_host = ComputeStrides(out_shape); int64_t *device_buffer; - cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), cuda_stream); + CUDA_CHECK(cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), cuda_stream)); int64_t *device_a_strides, *device_b_strides, *device_out_strides, *device_a_shape, *device_b_shape; device_a_strides = device_buffer + ndim * 0; @@ -123,8 +123,7 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end()); host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end()); - cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, - cuda_stream); + CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice)); LaunchKernel( [&](dim3 grid, dim3 block, size_t offset, const T *a_ptr, const T *b_ptr) { @@ -134,7 +133,7 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input }, output, inputs...); - cudaFreeAsync(device_buffer, cuda_stream); + CUDA_CHECK(cudaFreeAsync(device_buffer, cuda_stream)); } else { 
static_assert(sizeof...(inputs) == 1 || sizeof...(inputs) == 2, "LaunchForward currently only supports unary and binary operations."); @@ -538,7 +537,7 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out auto out_stride_host = ComputeStrides(out_shape); int64_t *device_buffer; - cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), stream); + CUDA_CHECK(cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), stream)); int64_t *device_a_strides, *device_b_strides, *device_out_strides, *device_a_shape, *device_b_shape; device_a_strides = device_buffer + ndim * 0; @@ -554,7 +553,7 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end()); host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end()); - cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream); + CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice)); const size_t num_elements = grad_output->NumElements(); @@ -616,7 +615,7 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out }, output_a, inputs...); } - cudaFreeAsync(device_buffer, stream); + CUDA_CHECK(cudaFreeAsync(device_buffer, stream)); } template std::shared_ptr UnaryForward(const std::shared_ptr &input, Func unary_fn) { diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index d318465c..38ae2b35 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -89,12 +89,9 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, int64_t *in_strides_dev = dev_buf + 1 * num_dims; int64_t *out_strides_dev = dev_buf + 2 * num_dims; - CUDA_CHECK( - cudaMemcpyAsync(out_dims_dev, idx_dims.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK( - cudaMemcpyAsync(in_strides_dev, in_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(out_strides_dev, out_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, - stream)); + CUDA_CHECK(cudaMemcpy(out_dims_dev, idx_dims.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(in_strides_dev, in_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(out_strides_dev, out_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice)); const int threads = 256; const int blocks = (total_elements + threads - 1) / threads; @@ -198,11 +195,10 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ int64_t *in_strides_dev = out_dims_dev + n_out; int64_t *out_strides_dev = in_strides_dev + n_in_strides; - CUDA_CHECK(cudaMemcpyAsync(out_dims_dev, idx_dims.data(), n_out * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(in_strides_dev, in_strides.data(), n_in_strides * sizeof(int64_t), - cudaMemcpyHostToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(out_strides_dev, out_strides.data(), n_out_strides * sizeof(int64_t), - cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpy(out_dims_dev, idx_dims.data(), n_out * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(in_strides_dev, in_strides.data(), n_in_strides * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(out_strides_dev, out_strides.data(), n_out_strides * sizeof(int64_t), cudaMemcpyHostToDevice)); 
const int threads = 256; const int blocks = (int)((total_elements + threads - 1) / threads); diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 29a8f1ae..6af0510e 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -73,21 +73,21 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - cudaMallocAsync(&new_dims_dev, - (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), - stream); + CUDA_CHECK(cudaMallocAsync( + &new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), + stream)); starts_dev = new_dims_dev + ends.size(); steps_dev = starts_dev + starts.size(); input_strides_dev = steps_dev + steps.size(); output_strides_dev = input_strides_dev + dims.size(); - cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice, - stream); + CUDA_CHECK(cudaMemcpy(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; @@ -167,21 +167,21 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output const auto &stream = dynamic_cast( infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device)) ->cuda_stream(); - cudaMallocAsync(&new_dims_dev, - (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), - stream); + CUDA_CHECK(cudaMallocAsync( + &new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t), + stream)); starts_dev = new_dims_dev + ends.size(); steps_dev = starts_dev + starts.size(); input_strides_dev = steps_dev + steps.size(); output_strides_dev = input_strides_dev + dims.size(); - cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice, - stream); + CUDA_CHECK(cudaMemcpy(new_dims_dev, 
new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; @@ -195,7 +195,7 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output }, "CUDA SliceBackward"); - cudaFreeAsync(new_dims_dev, stream); + CUDA_CHECK(cudaFreeAsync(new_dims_dev, stream)); return grad_input; } diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index ec258976..941cbcff 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -133,18 +133,18 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di void *device_ptr; const T **device_grad_output_ptrs; int64_t *device_H_outs; - cudaMallocAsync(&device_ptr, (sizeof(T *) + sizeof(int64_t)) * num_splits, stream); + CUDA_CHECK(cudaMallocAsync(&device_ptr, (sizeof(T *) + sizeof(int64_t)) * num_splits, stream)); device_grad_output_ptrs = (const T **)(device_ptr); device_H_outs = reinterpret_cast(device_grad_output_ptrs + num_splits); - cudaMemcpyAsync(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits, - cudaMemcpyHostToDevice, stream); + CUDA_CHECK(cudaMemcpy(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits, + cudaMemcpyHostToDevice)); // init H_out for each split std::vector H_outs(num_splits); for (int i = 0; i < num_splits; ++i) { H_outs[i] = std::min(split_size, H_in - i * split_size); } - cudaMemcpyAsync(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice, stream); + CUDA_CHECK(cudaMemcpy(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice)); int64_t total_elements = N * H_in * W; int threads_per_block = 256; @@ -154,7 +154,7 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di static_cast(grad_input->DataPtr()), N, H_in, W, split_size, num_splits, device_H_outs); - cudaFreeAsync(device_ptr, stream); + CUDA_CHECK(cudaFreeAsync(device_ptr, stream)); return grad_input; } diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index 56067cb8..788f0145 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -67,14 +67,14 @@ std::shared_ptr StackForward(const std::vector> for (const auto &t : inputs) { host_input_ptrs.push_back(static_cast(t->DataPtr())); } const T **device_input_ptrs; - cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream); - cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice, - stream); + CUDA_CHECK(cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream)); + CUDA_CHECK(cudaMemcpy(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, + cudaMemcpyHostToDevice)); StackForwardKernel<<>>( device_input_ptrs, static_cast(output->DataPtr()), N, D, num_inputs); - cudaFreeAsync(device_input_ptrs, stream); + CUDA_CHECK(cudaFreeAsync(device_input_ptrs, stream)); }, "CUDA 
StackForward"); @@ -136,13 +136,13 @@ std::vector> StackBackward(const std::vector &i for (auto &t : grads) { host_ptrs.push_back(static_cast(t->DataPtr())); } T **device_ptrs; - cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream); - cudaMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice, stream); + CUDA_CHECK(cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream)); + CUDA_CHECK(cudaMemcpy(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice)); StackBackwardKernel<<>>( static_cast(grad_output->DataPtr()), device_ptrs, N, D, num_inputs); - cudaFreeAsync(device_ptrs, stream); + CUDA_CHECK(cudaFreeAsync(device_ptrs, stream)); }, "CUDA StackBackward"); diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index 62d316b1..b75861f7 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -252,7 +252,7 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i // Allocate device memory for dims and strides // TODO(zbl): avoid using cudaMalloc? int64_t *device_buffer; - cudaMallocAsync(&device_buffer, 3 * ndim * sizeof(int64_t), stream); + CUDA_CHECK(cudaMallocAsync(&device_buffer, 3 * ndim * sizeof(int64_t), stream)); int64_t *in_dims_dev = device_buffer; int64_t *in_strides_dev = device_buffer + ndim; @@ -263,7 +263,7 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i host_buffer.insert(host_buffer.end(), in_strides.begin(), in_strides.end()); host_buffer.insert(host_buffer.end(), out_strides.begin(), out_strides.end()); - cudaMemcpyAsync(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream); + CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice)); int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; @@ -278,7 +278,7 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i }, "CUDA TransposeForward"); - cudaFreeAsync(device_buffer, stream); + CUDA_CHECK(cudaFreeAsync(device_buffer, stream)); return output; } diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 27f473c2..2cfaccab 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -50,9 +50,8 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean core::DeviceGuard guard(device); auto impl = core::GetDeviceGuardImpl(device.type()); - impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, - impl->GetStream(device)); + impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); return tensor; } @@ -143,9 +142,8 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, core::DeviceGuard guard(device); auto impl = core::GetDeviceGuardImpl(device.type()); - impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, - impl->GetStream(device)); + impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); return tensor; } @@ -161,9 +159,8 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) { auto impl = core::GetDeviceGuardImpl(device.type()); - impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, - impl->GetStream(device)); + impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); return tensor; } @@ -179,9 +176,8 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { auto impl = core::GetDeviceGuardImpl(device.type()); - impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, - impl->GetStream(device)); + impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); return tensor; } @@ -190,7 +186,7 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { case DATA_TYPE: { \ std::vector buffer(num_elements); \ std::iota(buffer.begin(), buffer.end(), static_cast(start)); \ - impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind, stream); \ + impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind); \ break; \ } @@ -202,7 +198,6 @@ std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Devic auto *impl = core::GetDeviceGuardImpl(device.type()); const core::MemcpyKind kind = device.IsCPU() ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D; - core::Stream *stream = impl->GetStream(device); switch (dtype) { ARANGE_CASE(DataType::kUINT8, uint8_t) diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 774e36ad..32008a9d 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -67,9 +67,8 @@ Tensor::Tensor(const float *data, const std::vector &dims, DataType dty core::DeviceGuard guard(device); auto *impl = core::GetDeviceGuardImpl(device.type()); - impl->MemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, - impl->GetStream(device)); + impl->Memcpy(buffer_->DataPtr(), data, buffer_->Size(), + device.type() == Device::DeviceType::kCPU ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); } void Tensor::SetData(const Tensor &tensor, size_t offset, bool preserve_data) { @@ -162,16 +161,14 @@ Tensor Tensor::To(Device device) { new_tensor = Tensor(dims_, dtype_, Device()); core::DeviceGuard guard(buffer_device); auto impl = core::GetDeviceGuardImpl(buffer_device.type()); - impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H, - impl->GetStream(buffer_device)); + impl->Memcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H); } else if (buffer_device.type() == Device::DeviceType::kCPU) { new_tensor = Tensor(dims_, dtype_, device); // H2D core::DeviceGuard guard(device); auto *impl = core::GetDeviceGuardImpl(device.type()); - impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, - impl->GetStream(device)); + impl->Memcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D); } else { new_tensor = Tensor(dims_, dtype_, device); // P2P @@ -180,8 +177,7 @@ Tensor Tensor::To(Device device) { // 2. H2D core::DeviceGuard guard(buffer_device); auto *impl = core::GetDeviceGuardImpl(buffer_device.type()); - impl->MemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, - impl->GetStream(buffer_device)); + impl->Memcpy(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D); } if (grad_) { @@ -230,17 +226,17 @@ void Tensor::CopyFrom(const Tensor &src) { if (dst_dev == src_dev) { core::DeviceGuard guard(dst_dev); auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); - impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D, impl->GetStream(dst_dev)); + impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D); } else if (dst_dev.type() == Device::DeviceType::kCPU) { // D2H core::DeviceGuard guard(src_dev); auto *impl = core::GetDeviceGuardImpl(src_dev.type()); - impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H, impl->GetStream(src_dev)); + impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H); } else if (src_dev.type() == Device::DeviceType::kCPU) { // H2D core::DeviceGuard guard(dst_dev); auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); - impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); + impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D); } else { // TODO(dcj): maybe support p2p api later // P2P @@ -250,7 +246,7 @@ void Tensor::CopyFrom(const Tensor &src) { // 2. 
H2D core::DeviceGuard guard(dst_dev); auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); - impl->MemcpyAsync(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); + impl->Memcpy(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D); } } From c35455dbda85af5057428ec58ce1fdc1514f86bc Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 13 Feb 2026 09:37:01 +0000 Subject: [PATCH 2/2] fix: replace unsafe d2h/h2d Memcpy calls with synchronous MemcpyAsync + SynchronizeStream --- infini_train/src/core/cuda/cuda_guard_impl.cc | 4 +++ infini_train/src/core/cuda/cuda_guard_impl.h | 2 ++ infini_train/src/kernels/cuda/concat.cu | 8 +++++ infini_train/src/kernels/cuda/elementwise.cu | 15 ++++++-- infini_train/src/kernels/cuda/gather.cu | 10 ++++++ infini_train/src/kernels/cuda/slice.cu | 36 +++++++++++++------ infini_train/src/kernels/cuda/split.cu | 12 +++++-- infini_train/src/kernels/cuda/stack.cu | 17 +++++++-- infini_train/src/kernels/cuda/transform.cu | 8 ++++- infini_train/src/nn/init.cc | 28 ++++++++++----- infini_train/src/tensor.cc | 30 +++++++++++----- 11 files changed, 133 insertions(+), 37 deletions(-) diff --git a/infini_train/src/core/cuda/cuda_guard_impl.cc b/infini_train/src/core/cuda/cuda_guard_impl.cc index f6c42a6d..dfa5e1cc 100644 --- a/infini_train/src/core/cuda/cuda_guard_impl.cc +++ b/infini_train/src/core/cuda/cuda_guard_impl.cc @@ -100,6 +100,10 @@ void CudaGuardImpl::SynchronizeDevice(Device device) const { SetDevice(original_device); } +void CudaGuardImpl::SynchronizeStream(Stream *stream) const { + CUDA_CHECK(cudaStreamSynchronize(dynamic_cast(stream)->cuda_stream())); +} + // blas BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const { CheckCudaDevice(device); diff --git a/infini_train/src/core/cuda/cuda_guard_impl.h b/infini_train/src/core/cuda/cuda_guard_impl.h index 7a64dbe4..cef49fdb 100644 --- a/infini_train/src/core/cuda/cuda_guard_impl.h +++ b/infini_train/src/core/cuda/cuda_guard_impl.h @@ -35,6 +35,8 @@ class CudaGuardImpl final : public DeviceGuardImpl { // sync void SynchronizeDevice(Device device) const override; + void SynchronizeStream(Stream *) const override; + // blas BlasHandle *GetBlasHandle(Device device) const override; diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu index f475fd1e..bebdafe7 100644 --- a/infini_train/src/kernels/cuda/concat.cu +++ b/infini_train/src/kernels/cuda/concat.cu @@ -125,6 +125,10 @@ std::shared_ptr ConcatForward(const std::vector> CUDA_CHECK(cudaFreeAsync(device_input_ptrs, stream)); CUDA_CHECK(cudaFreeAsync(device_offsets, stream)); + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. 
+ CUDA_CHECK(cudaStreamSynchronize(stream)); }, "CUDA ConcatForward"); @@ -230,6 +234,10 @@ std::vector> ConcatBackward(const std::shared_ptr &output, const Input host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end()); host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end()); - CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), + cudaMemcpyHostToDevice, cuda_stream)); LaunchKernel( [&](dim3 grid, dim3 block, size_t offset, const T *a_ptr, const T *b_ptr) { @@ -134,6 +135,11 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input output, inputs...); CUDA_CHECK(cudaFreeAsync(device_buffer, cuda_stream)); + + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. + CUDA_CHECK(cudaStreamSynchronize(cuda_stream)); } else { static_assert(sizeof...(inputs) == 1 || sizeof...(inputs) == 2, "LaunchForward currently only supports unary and binary operations."); @@ -553,7 +559,8 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end()); host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end()); - CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); const size_t num_elements = grad_output->NumElements(); @@ -616,6 +623,10 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out output_a, inputs...); } CUDA_CHECK(cudaFreeAsync(device_buffer, stream)); + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. + CUDA_CHECK(cudaStreamSynchronize(stream)); } template std::shared_ptr UnaryForward(const std::shared_ptr &input, Func unary_fn) { diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu index 38ae2b35..e498e015 100644 --- a/infini_train/src/kernels/cuda/gather.cu +++ b/infini_train/src/kernels/cuda/gather.cu @@ -107,6 +107,12 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input, "CUDA IndexGatherForward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); + + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. + CUDA_CHECK(cudaStreamSynchronize(stream)); + return out; } @@ -214,6 +220,10 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_ "CUDA IndexGatherBackward"); CUDA_CHECK(cudaFreeAsync(dev_buf, stream)); + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. 
+ CUDA_CHECK(cudaStreamSynchronize(stream)); return grad_input; } diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu index 6af0510e..c2ee80ba 100644 --- a/infini_train/src/kernels/cuda/slice.cu +++ b/infini_train/src/kernels/cuda/slice.cu @@ -81,13 +81,16 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const input_strides_dev = steps_dev + steps.size(); output_strides_dev = input_strides_dev + dims.size(); - CUDA_CHECK(cudaMemcpy(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); CUDA_CHECK( - cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); CUDA_CHECK( - cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK( + cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; @@ -103,6 +106,11 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const cudaFreeAsync(new_dims_dev, stream); + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. 
+ CUDA_CHECK(cudaStreamSynchronize(stream)); + return new_tensor; } @@ -175,13 +183,16 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output input_strides_dev = steps_dev + steps.size(); output_strides_dev = input_strides_dev + dims.size(); - CUDA_CHECK(cudaMemcpy(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); CUDA_CHECK( - cudaMemcpy(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); CUDA_CHECK( - cudaMemcpy(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice)); + cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK( + cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), + cudaMemcpyHostToDevice, stream)); int threads_per_block = 256; int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block; @@ -197,6 +208,11 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output CUDA_CHECK(cudaFreeAsync(new_dims_dev, stream)); + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. + CUDA_CHECK(cudaStreamSynchronize(stream)); + return grad_input; } } // namespace infini_train::kernels::cuda diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu index 941cbcff..89f04609 100644 --- a/infini_train/src/kernels/cuda/split.cu +++ b/infini_train/src/kernels/cuda/split.cu @@ -137,14 +137,15 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di device_grad_output_ptrs = (const T **)(device_ptr); device_H_outs = reinterpret_cast(device_grad_output_ptrs + num_splits); - CUDA_CHECK(cudaMemcpy(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits, - cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits, + cudaMemcpyHostToDevice, stream)); // init H_out for each split std::vector H_outs(num_splits); for (int i = 0; i < num_splits; ++i) { H_outs[i] = std::min(split_size, H_in - i * split_size); } - CUDA_CHECK(cudaMemcpy(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpyAsync(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice, stream)); int64_t total_elements = N * H_in * W; int threads_per_block = 256; @@ -156,6 +157,11 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di CUDA_CHECK(cudaFreeAsync(device_ptr, stream)); + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. 
+ CUDA_CHECK(cudaStreamSynchronize(stream)); + return grad_input; } diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu index 788f0145..fe70cc2f 100644 --- a/infini_train/src/kernels/cuda/stack.cu +++ b/infini_train/src/kernels/cuda/stack.cu @@ -68,13 +68,18 @@ std::shared_ptr StackForward(const std::vector> const T **device_input_ptrs; CUDA_CHECK(cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream)); - CUDA_CHECK(cudaMemcpy(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, - cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, + cudaMemcpyHostToDevice, stream)); StackForwardKernel<<>>( device_input_ptrs, static_cast(output->DataPtr()), N, D, num_inputs); CUDA_CHECK(cudaFreeAsync(device_input_ptrs, stream)); + + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. + CUDA_CHECK(cudaStreamSynchronize(stream)); }, "CUDA StackForward"); @@ -137,12 +142,18 @@ std::vector> StackBackward(const std::vector &i T **device_ptrs; CUDA_CHECK(cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream)); - CUDA_CHECK(cudaMemcpy(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice, + stream)); StackBackwardKernel<<>>( static_cast(grad_output->DataPtr()), device_ptrs, N, D, num_inputs); CUDA_CHECK(cudaFreeAsync(device_ptrs, stream)); + + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. + CUDA_CHECK(cudaStreamSynchronize(stream)); }, "CUDA StackBackward"); diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu index b75861f7..4e9630bd 100644 --- a/infini_train/src/kernels/cuda/transform.cu +++ b/infini_train/src/kernels/cuda/transform.cu @@ -263,7 +263,8 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i host_buffer.insert(host_buffer.end(), in_strides.begin(), in_strides.end()); host_buffer.insert(host_buffer.end(), out_strides.begin(), out_strides.end()); - CUDA_CHECK(cudaMemcpy(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice)); + CUDA_CHECK( + cudaMemcpyAsync(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream)); int threads_per_block = 256; int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block; @@ -280,6 +281,11 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i CUDA_CHECK(cudaFreeAsync(device_buffer, stream)); + // NOTE(dcj): + // Synchronize the stream here to ensure all preceding H2D/D2H memcpy + // operations have completed before the host buffers go out of scope. + CUDA_CHECK(cudaStreamSynchronize(stream)); + return output; } diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc index 2cfaccab..15144371 100644 --- a/infini_train/src/nn/init.cc +++ b/infini_train/src/nn/init.cc @@ -50,8 +50,10 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean core::DeviceGuard guard(device); auto impl = core::GetDeviceGuardImpl(device.type()); - impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); + impl->SynchronizeStream(impl->GetStream(device)); return tensor; } @@ -142,8 +144,10 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a, core::DeviceGuard guard(device); auto impl = core::GetDeviceGuardImpl(device.type()); - impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); + impl->SynchronizeStream(impl->GetStream(device)); return tensor; } @@ -159,8 +163,10 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) { auto impl = core::GetDeviceGuardImpl(device.type()); - impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); + impl->SynchronizeStream(impl->GetStream(device)); return tensor; } @@ -176,8 +182,10 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { auto impl = core::GetDeviceGuardImpl(device.type()); - impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float), + device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); + impl->SynchronizeStream(impl->GetStream(device)); return tensor; } @@ -186,7 +194,8 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) { case DATA_TYPE: { \ std::vector buffer(num_elements); \ std::iota(buffer.begin(), buffer.end(), static_cast(start)); \ - impl->Memcpy(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind); \ + impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind, stream); \ + impl->SynchronizeStream(stream); \ break; \ } @@ -198,6 +207,7 @@ std::shared_ptr Arange(int64_t start, int64_t end, DataType dtype, Devic auto *impl = core::GetDeviceGuardImpl(device.type()); const core::MemcpyKind kind = device.IsCPU() ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D; + core::Stream *stream = impl->GetStream(device); switch (dtype) { ARANGE_CASE(DataType::kUINT8, uint8_t) diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc index 32008a9d..06c6354e 100644 --- a/infini_train/src/tensor.cc +++ b/infini_train/src/tensor.cc @@ -67,8 +67,10 @@ Tensor::Tensor(const float *data, const std::vector &dims, DataType dty core::DeviceGuard guard(device); auto *impl = core::GetDeviceGuardImpl(device.type()); - impl->Memcpy(buffer_->DataPtr(), data, buffer_->Size(), - device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D); + impl->MemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(), + device.type() == Device::DeviceType::kCPU ? 
core::MemcpyKind::kD2D : core::MemcpyKind::kH2D, + impl->GetStream(device)); + impl->SynchronizeStream(impl->GetStream(device)); } void Tensor::SetData(const Tensor &tensor, size_t offset, bool preserve_data) { @@ -161,14 +163,18 @@ Tensor Tensor::To(Device device) { new_tensor = Tensor(dims_, dtype_, Device()); core::DeviceGuard guard(buffer_device); auto impl = core::GetDeviceGuardImpl(buffer_device.type()); - impl->Memcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H); + impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H, + impl->GetStream(buffer_device)); + impl->SynchronizeStream(impl->GetStream(buffer_device)); } else if (buffer_device.type() == Device::DeviceType::kCPU) { new_tensor = Tensor(dims_, dtype_, device); // H2D core::DeviceGuard guard(device); auto *impl = core::GetDeviceGuardImpl(device.type()); - impl->Memcpy(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D); + impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, + impl->GetStream(device)); + impl->SynchronizeStream(impl->GetStream(device)); } else { new_tensor = Tensor(dims_, dtype_, device); // P2P @@ -177,7 +183,9 @@ Tensor Tensor::To(Device device) { // 2. H2D core::DeviceGuard guard(buffer_device); auto *impl = core::GetDeviceGuardImpl(buffer_device.type()); - impl->Memcpy(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D); + impl->MemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D, + impl->GetStream(buffer_device)); + impl->SynchronizeStream(impl->GetStream(buffer_device)); } if (grad_) { @@ -226,17 +234,20 @@ void Tensor::CopyFrom(const Tensor &src) { if (dst_dev == src_dev) { core::DeviceGuard guard(dst_dev); auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); - impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D, impl->GetStream(dst_dev)); + impl->SynchronizeStream(impl->GetStream(dst_dev)); } else if (dst_dev.type() == Device::DeviceType::kCPU) { // D2H core::DeviceGuard guard(src_dev); auto *impl = core::GetDeviceGuardImpl(src_dev.type()); - impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H, impl->GetStream(src_dev)); + impl->SynchronizeStream(impl->GetStream(src_dev)); } else if (src_dev.type() == Device::DeviceType::kCPU) { // H2D core::DeviceGuard guard(dst_dev); auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); - impl->Memcpy(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D); + impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); + impl->SynchronizeStream(impl->GetStream(dst_dev)); } else { // TODO(dcj): maybe support p2p api later // P2P @@ -246,7 +257,8 @@ void Tensor::CopyFrom(const Tensor &src) { // 2. H2D core::DeviceGuard guard(dst_dev); auto *impl = core::GetDeviceGuardImpl(dst_dev.type()); - impl->Memcpy(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D); + impl->MemcpyAsync(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev)); + impl->SynchronizeStream(impl->GetStream(dst_dev)); } }
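
Illustrative appendix (not part of either patch): the standalone sketch below restates the pattern PATCH 2/2 converges on — a stream-ordered allocation with cudaMallocAsync, an H2D cudaMemcpyAsync on the same stream, and a cudaStreamSynchronize before the pageable host buffer is destroyed, which is exactly the hazard the NOTE(dcj) comments describe. It assumes CUDA 11.2+ (stream-ordered allocator); the CUDA_CHECK macro and the names UploadAndUse / FillFromMeta are generic stand-ins, not the project's actual helpers or API.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

#define CUDA_CHECK(expr)                                                       \
    do {                                                                       \
        cudaError_t err_ = (expr);                                             \
        if (err_ != cudaSuccess) {                                             \
            std::fprintf(stderr, "CUDA error: %s (%s:%d)\n",                   \
                         cudaGetErrorString(err_), __FILE__, __LINE__);        \
            std::abort();                                                      \
        }                                                                      \
    } while (0)

// Writes meta[i % n] into out[i]; stands in for kernels such as SliceForward
// that read shape/stride metadata from device memory.
__global__ void FillFromMeta(const int64_t *meta, int n, int64_t *out, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        out[idx] = meta[idx % n];
    }
}

void UploadAndUse(const std::vector<int64_t> &host_meta, int64_t *out, int total,
                  cudaStream_t stream) {
    const size_t bytes = host_meta.size() * sizeof(int64_t);
    int64_t *device_meta = nullptr;

    // Stream-ordered allocation: the buffer is only guaranteed to exist for
    // work queued on `stream`, which is why the plain cudaMemcpy introduced by
    // PATCH 1/2 (legacy default stream) was itself replaced in PATCH 2/2.
    CUDA_CHECK(cudaMallocAsync(&device_meta, bytes, stream));

    // Async H2D copy from a pageable std::vector: the runtime may still read
    // the host buffer after this call returns.
    CUDA_CHECK(cudaMemcpyAsync(device_meta, host_meta.data(), bytes,
                               cudaMemcpyHostToDevice, stream));

    const int threads = 256;
    const int blocks = (total + threads - 1) / threads;
    FillFromMeta<<<blocks, threads, 0, stream>>>(
        device_meta, static_cast<int>(host_meta.size()), out, total);

    CUDA_CHECK(cudaFreeAsync(device_meta, stream));  // stream-ordered free

    // Mirror of the patch's NOTE(dcj): synchronize so the caller may let
    // `host_meta` go out of scope without racing the in-flight copy.
    CUDA_CHECK(cudaStreamSynchronize(stream));
}

int main() {
    cudaStream_t stream;
    CUDA_CHECK(cudaStreamCreate(&stream));

    const int total = 1024;
    int64_t *out = nullptr;
    CUDA_CHECK(cudaMalloc(&out, total * sizeof(int64_t)));

    {
        std::vector<int64_t> meta = {4, 8, 16, 32};  // stand-in for dims/strides
        UploadAndUse(meta, out, total, stream);
    }  // `meta` is destroyed here; safe only because UploadAndUse synchronized.

    CUDA_CHECK(cudaFree(out));
    CUDA_CHECK(cudaStreamDestroy(stream));
    return 0;
}

The per-call cudaStreamSynchronize trades some host/device overlap for safety, which matches the intent of the series; an alternative the patches do not take would be to stage metadata through pinned (cudaHostAlloc) buffers whose lifetime outlives the stream work, which would remove the synchronize from the hot path.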