diff --git a/infini_train/src/core/cuda/cuda_guard_impl.cc b/infini_train/src/core/cuda/cuda_guard_impl.cc
index f6c42a6d..dfa5e1cc 100644
--- a/infini_train/src/core/cuda/cuda_guard_impl.cc
+++ b/infini_train/src/core/cuda/cuda_guard_impl.cc
@@ -100,6 +100,10 @@ void CudaGuardImpl::SynchronizeDevice(Device device) const {
     SetDevice(original_device);
 }
 
+void CudaGuardImpl::SynchronizeStream(Stream *stream) const {
+    CUDA_CHECK(cudaStreamSynchronize(dynamic_cast(stream)->cuda_stream()));
+}
+
 // blas
 BlasHandle *CudaGuardImpl::GetBlasHandle(Device device) const {
     CheckCudaDevice(device);
diff --git a/infini_train/src/core/cuda/cuda_guard_impl.h b/infini_train/src/core/cuda/cuda_guard_impl.h
index 7a64dbe4..cef49fdb 100644
--- a/infini_train/src/core/cuda/cuda_guard_impl.h
+++ b/infini_train/src/core/cuda/cuda_guard_impl.h
@@ -35,6 +35,8 @@ class CudaGuardImpl final : public DeviceGuardImpl {
     // sync
     void SynchronizeDevice(Device device) const override;
 
+    void SynchronizeStream(Stream *) const override;
+
     // blas
     BlasHandle *GetBlasHandle(Device device) const override;
diff --git a/infini_train/src/kernels/cuda/concat.cu b/infini_train/src/kernels/cuda/concat.cu
index 1fa8face..bebdafe7 100644
--- a/infini_train/src/kernels/cuda/concat.cu
+++ b/infini_train/src/kernels/cuda/concat.cu
@@ -113,18 +113,22 @@ std::shared_ptr ConcatForward(const std::vector>
     int64_t *device_offsets = nullptr;
 
     CUDA_CHECK(cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream));
-    CUDA_CHECK(cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
-                               cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpy(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
+                          cudaMemcpyHostToDevice));
     CUDA_CHECK(cudaMallocAsync(&device_offsets, sizeof(int64_t) * (num_inputs + 1), stream));
-    CUDA_CHECK(cudaMemcpyAsync(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
-                               cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpy(device_offsets, host_offsets.data(), sizeof(int64_t) * (num_inputs + 1),
+                          cudaMemcpyHostToDevice));
 
     ConcatForwardKernel<<>>(
         device_input_ptrs, static_cast(output->DataPtr()), device_offsets, N, D, num_inputs, K_total);
 
     CUDA_CHECK(cudaFreeAsync(device_input_ptrs, stream));
     CUDA_CHECK(cudaFreeAsync(device_offsets, stream));
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     },
     "CUDA ConcatForward");
@@ -219,18 +223,21 @@ std::vector> ConcatBackward(const std::shared_ptr<<>>(
         static_cast(grad_output->DataPtr()), device_ptrs, device_offsets, N, D, num_inputs, K_total);
 
     CUDA_CHECK(cudaFreeAsync(device_ptrs, stream));
     CUDA_CHECK(cudaFreeAsync(device_offsets, stream));
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     },
     "CUDA ConcatBackward");
diff --git a/infini_train/src/kernels/cuda/elementwise.cu b/infini_train/src/kernels/cuda/elementwise.cu
index 1b9fe9eb..dd6cab76 100644
--- a/infini_train/src/kernels/cuda/elementwise.cu
+++ b/infini_train/src/kernels/cuda/elementwise.cu
@@ -107,7 +107,7 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input
     auto out_stride_host = ComputeStrides(out_shape);
 
     int64_t *device_buffer;
-    cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), cuda_stream);
+    CUDA_CHECK(cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), cuda_stream));
 
     int64_t *device_a_strides, *device_b_strides, *device_out_strides, *device_a_shape, *device_b_shape;
     device_a_strides = device_buffer + ndim * 0;
@@ -123,8 +123,8 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input
     host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end());
     host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());
 
-    cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice,
-                    cuda_stream);
+    CUDA_CHECK(cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t),
+                               cudaMemcpyHostToDevice, cuda_stream));
 
     LaunchKernel(
         [&](dim3 grid, dim3 block, size_t offset, const T *a_ptr, const T *b_ptr) {
@@ -134,7 +134,12 @@ void LaunchForward(Func func, const std::shared_ptr &output, const Input
         },
         output, inputs...);
 
-    cudaFreeAsync(device_buffer, cuda_stream);
+    CUDA_CHECK(cudaFreeAsync(device_buffer, cuda_stream));
+
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(cuda_stream));
 } else {
     static_assert(sizeof...(inputs) == 1 || sizeof...(inputs) == 2,
                   "LaunchForward currently only supports unary and binary operations.");
@@ -538,7 +543,7 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out
     auto out_stride_host = ComputeStrides(out_shape);
 
     int64_t *device_buffer;
-    cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), stream);
+    CUDA_CHECK(cudaMallocAsync(&device_buffer, 5 * ndim * sizeof(int64_t), stream));
 
     int64_t *device_a_strides, *device_b_strides, *device_out_strides, *device_a_shape, *device_b_shape;
     device_a_strides = device_buffer + ndim * 0;
@@ -554,7 +559,8 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out
     host_buffer.insert(host_buffer.end(), a_shape.begin(), a_shape.end());
     host_buffer.insert(host_buffer.end(), b_shape.begin(), b_shape.end());
 
-    cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
+    CUDA_CHECK(
+        cudaMemcpyAsync(device_buffer, host_buffer.data(), 5 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
 
     const size_t num_elements = grad_output->NumElements();
@@ -616,7 +622,11 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr &out
         },
         output_a, inputs...);
     }
-    cudaFreeAsync(device_buffer, stream);
+    CUDA_CHECK(cudaFreeAsync(device_buffer, stream));
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
 }
 
 template std::shared_ptr UnaryForward(const std::shared_ptr &input, Func unary_fn) {
diff --git a/infini_train/src/kernels/cuda/gather.cu b/infini_train/src/kernels/cuda/gather.cu
index d318465c..e498e015 100644
--- a/infini_train/src/kernels/cuda/gather.cu
+++ b/infini_train/src/kernels/cuda/gather.cu
@@ -89,12 +89,9 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input,
     int64_t *in_strides_dev = dev_buf + 1 * num_dims;
     int64_t *out_strides_dev = dev_buf + 2 * num_dims;
 
-    CUDA_CHECK(
-        cudaMemcpyAsync(out_dims_dev, idx_dims.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(
-        cudaMemcpyAsync(in_strides_dev, in_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(cudaMemcpyAsync(out_strides_dev, out_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice,
-                               stream));
+    CUDA_CHECK(cudaMemcpy(out_dims_dev, idx_dims.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(in_strides_dev, in_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(out_strides_dev, out_strides.data(), num_dims * sizeof(int64_t), cudaMemcpyHostToDevice));
 
     const int threads = 256;
     const int blocks = (total_elements + threads - 1) / threads;
@@ -110,6 +107,12 @@ std::shared_ptr IndexGatherForward(const std::shared_ptr &input,
     "CUDA IndexGatherForward");
 
     CUDA_CHECK(cudaFreeAsync(dev_buf, stream));
+
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
     return out;
 }
@@ -198,11 +201,10 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_
     int64_t *in_strides_dev = out_dims_dev + n_out;
     int64_t *out_strides_dev = in_strides_dev + n_in_strides;
 
-    CUDA_CHECK(cudaMemcpyAsync(out_dims_dev, idx_dims.data(), n_out * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(cudaMemcpyAsync(in_strides_dev, in_strides.data(), n_in_strides * sizeof(int64_t),
-                               cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(cudaMemcpyAsync(out_strides_dev, out_strides.data(), n_out_strides * sizeof(int64_t),
-                               cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpy(out_dims_dev, idx_dims.data(), n_out * sizeof(int64_t), cudaMemcpyHostToDevice));
+    CUDA_CHECK(cudaMemcpy(in_strides_dev, in_strides.data(), n_in_strides * sizeof(int64_t), cudaMemcpyHostToDevice));
+    CUDA_CHECK(
+        cudaMemcpy(out_strides_dev, out_strides.data(), n_out_strides * sizeof(int64_t), cudaMemcpyHostToDevice));
 
     const int threads = 256;
     const int blocks = (int)((total_elements + threads - 1) / threads);
@@ -218,6 +220,10 @@ std::shared_ptr IndexGatherBackward(const std::shared_ptr &grad_
     "CUDA IndexGatherBackward");
 
     CUDA_CHECK(cudaFreeAsync(dev_buf, stream));
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     return grad_input;
 }
diff --git a/infini_train/src/kernels/cuda/slice.cu b/infini_train/src/kernels/cuda/slice.cu
index 29a8f1ae..c2ee80ba 100644
--- a/infini_train/src/kernels/cuda/slice.cu
+++ b/infini_train/src/kernels/cuda/slice.cu
@@ -73,21 +73,24 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const
         infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
         ->cuda_stream();
 
-    cudaMallocAsync(&new_dims_dev,
-                    (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t),
-                    stream);
+    CUDA_CHECK(cudaMallocAsync(
+        &new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t),
+        stream));
 
     starts_dev = new_dims_dev + ends.size();
     steps_dev = starts_dev + starts.size();
     input_strides_dev = steps_dev + steps.size();
     output_strides_dev = input_strides_dev + dims.size();
 
-    cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-    cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-    cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-    cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-                    stream);
-    cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-                    stream);
+    CUDA_CHECK(
+        cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(
+        cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(
+        cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t),
+                               cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t),
+                               cudaMemcpyHostToDevice, stream));
 
     int threads_per_block = 256;
     int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
@@ -103,6 +106,11 @@ std::shared_ptr SliceForward(const std::shared_ptr &input, const
     cudaFreeAsync(new_dims_dev, stream);
 
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+
     return new_tensor;
 }
@@ -167,21 +175,24 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output
     const auto &stream = dynamic_cast(
         infini_train::core::GetDeviceGuardImpl(device.type())->GetStream(device))
         ->cuda_stream();
-    cudaMallocAsync(&new_dims_dev,
-                    (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t),
-                    stream);
+    CUDA_CHECK(cudaMallocAsync(
+        &new_dims_dev, (ends.size() + starts.size() + steps.size() + dims.size() + new_dims.size()) * sizeof(int64_t),
+        stream));
 
     starts_dev = new_dims_dev + ends.size();
     steps_dev = starts_dev + starts.size();
     input_strides_dev = steps_dev + steps.size();
     output_strides_dev = input_strides_dev + dims.size();
 
-    cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-    cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-    cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
-    cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-                    stream);
-    cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t), cudaMemcpyHostToDevice,
-                    stream);
+    CUDA_CHECK(
+        cudaMemcpyAsync(new_dims_dev, new_dims.data(), ends.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(
+        cudaMemcpyAsync(starts_dev, starts.data(), starts.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(
+        cudaMemcpyAsync(steps_dev, steps.data(), steps.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpyAsync(input_strides_dev, src_strides.data(), dims.size() * sizeof(int64_t),
+                               cudaMemcpyHostToDevice, stream));
+    CUDA_CHECK(cudaMemcpyAsync(output_strides_dev, dst_strides.data(), new_dims.size() * sizeof(int64_t),
+                               cudaMemcpyHostToDevice, stream));
 
     int threads_per_block = 256;
     int num_blocks = (total_elements + threads_per_block - 1) / threads_per_block;
@@ -195,7 +206,12 @@ std::shared_ptr SliceBackward(const std::shared_ptr &grad_output
     },
     "CUDA SliceBackward");
 
-    cudaFreeAsync(new_dims_dev, stream);
+    CUDA_CHECK(cudaFreeAsync(new_dims_dev, stream));
+
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     return grad_input;
 }
diff --git a/infini_train/src/kernels/cuda/split.cu b/infini_train/src/kernels/cuda/split.cu
index ec258976..89f04609 100644
--- a/infini_train/src/kernels/cuda/split.cu
+++ b/infini_train/src/kernels/cuda/split.cu
@@ -133,18 +133,19 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di
     void *device_ptr;
     const T **device_grad_output_ptrs;
     int64_t *device_H_outs;
-    cudaMallocAsync(&device_ptr, (sizeof(T *) + sizeof(int64_t)) * num_splits, stream);
+    CUDA_CHECK(cudaMallocAsync(&device_ptr, (sizeof(T *) + sizeof(int64_t)) * num_splits, stream));
     device_grad_output_ptrs = (const T **)(device_ptr);
     device_H_outs = reinterpret_cast(device_grad_output_ptrs + num_splits);
 
-    cudaMemcpyAsync(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits,
-                    cudaMemcpyHostToDevice, stream);
+    CUDA_CHECK(cudaMemcpyAsync(device_grad_output_ptrs, host_grad_output_ptrs.data(), sizeof(T *) * num_splits,
+                               cudaMemcpyHostToDevice, stream));
 
     // init H_out for each split
     std::vector H_outs(num_splits);
     for (int i = 0; i < num_splits; ++i) { H_outs[i] = std::min(split_size, H_in - i * split_size); }
-    cudaMemcpyAsync(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice, stream);
+    CUDA_CHECK(
+        cudaMemcpyAsync(device_H_outs, H_outs.data(), sizeof(int64_t) * num_splits, cudaMemcpyHostToDevice, stream));
 
     int64_t total_elements = N * H_in * W;
     int threads_per_block = 256;
@@ -154,7 +155,12 @@ std::shared_ptr LaunchSplitBackward(const std::vector &input_di
         static_cast(grad_input->DataPtr()), N, H_in, W, split_size, num_splits, device_H_outs);
 
-    cudaFreeAsync(device_ptr, stream);
+    CUDA_CHECK(cudaFreeAsync(device_ptr, stream));
+
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
 
     return grad_input;
 }
diff --git a/infini_train/src/kernels/cuda/stack.cu b/infini_train/src/kernels/cuda/stack.cu
index 56067cb8..fe70cc2f 100644
--- a/infini_train/src/kernels/cuda/stack.cu
+++ b/infini_train/src/kernels/cuda/stack.cu
@@ -67,14 +67,19 @@ std::shared_ptr StackForward(const std::vector>
     for (const auto &t : inputs) { host_input_ptrs.push_back(static_cast(t->DataPtr())); }
 
     const T **device_input_ptrs;
-    cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream);
-    cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice,
-                    stream);
+    CUDA_CHECK(cudaMallocAsync(&device_input_ptrs, sizeof(T *) * num_inputs, stream));
+    CUDA_CHECK(cudaMemcpyAsync(device_input_ptrs, host_input_ptrs.data(), sizeof(T *) * num_inputs,
+                               cudaMemcpyHostToDevice, stream));
 
     StackForwardKernel<<>>(
         device_input_ptrs, static_cast(output->DataPtr()), N, D, num_inputs);
 
-    cudaFreeAsync(device_input_ptrs, stream);
+    CUDA_CHECK(cudaFreeAsync(device_input_ptrs, stream));
+
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     },
     "CUDA StackForward");
@@ -136,13 +141,19 @@ std::vector> StackBackward(const std::vector &i
     for (auto &t : grads) { host_ptrs.push_back(static_cast(t->DataPtr())); }
 
     T **device_ptrs;
-    cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream);
-    cudaMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice, stream);
+    CUDA_CHECK(cudaMallocAsync(&device_ptrs, sizeof(T *) * num_inputs, stream));
+    CUDA_CHECK(cudaMemcpyAsync(device_ptrs, host_ptrs.data(), sizeof(T *) * num_inputs, cudaMemcpyHostToDevice,
+                               stream));
 
     StackBackwardKernel<<>>(
         static_cast(grad_output->DataPtr()), device_ptrs, N, D, num_inputs);
 
-    cudaFreeAsync(device_ptrs, stream);
+    CUDA_CHECK(cudaFreeAsync(device_ptrs, stream));
+
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
     },
     "CUDA StackBackward");
diff --git a/infini_train/src/kernels/cuda/transform.cu b/infini_train/src/kernels/cuda/transform.cu
index 62d316b1..4e9630bd 100644
--- a/infini_train/src/kernels/cuda/transform.cu
+++ b/infini_train/src/kernels/cuda/transform.cu
@@ -252,7 +252,7 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i
     // Allocate device memory for dims and strides
     // TODO(zbl): avoid using cudaMalloc?
     int64_t *device_buffer;
-    cudaMallocAsync(&device_buffer, 3 * ndim * sizeof(int64_t), stream);
+    CUDA_CHECK(cudaMallocAsync(&device_buffer, 3 * ndim * sizeof(int64_t), stream));
 
     int64_t *in_dims_dev = device_buffer;
     int64_t *in_strides_dev = device_buffer + ndim;
@@ -263,7 +263,8 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i
     host_buffer.insert(host_buffer.end(), in_strides.begin(), in_strides.end());
     host_buffer.insert(host_buffer.end(), out_strides.begin(), out_strides.end());
 
-    cudaMemcpyAsync(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream);
+    CUDA_CHECK(
+        cudaMemcpyAsync(device_buffer, host_buffer.data(), 3 * ndim * sizeof(int64_t), cudaMemcpyHostToDevice, stream));
 
     int threads_per_block = 256;
     int num_blocks = (num_elements + threads_per_block - 1) / threads_per_block;
@@ -278,7 +279,12 @@ std::shared_ptr TransposeForward(const std::shared_ptr &input, i
     },
     "CUDA TransposeForward");
 
-    cudaFreeAsync(device_buffer, stream);
+    CUDA_CHECK(cudaFreeAsync(device_buffer, stream));
+
+    // NOTE(dcj):
+    // Synchronize the stream here to ensure all preceding H2D/D2H memcpy
+    // operations have completed before the host buffers go out of scope.
+    CUDA_CHECK(cudaStreamSynchronize(stream));
 
     return output;
 }
diff --git a/infini_train/src/nn/init.cc b/infini_train/src/nn/init.cc
index 27f473c2..15144371 100644
--- a/infini_train/src/nn/init.cc
+++ b/infini_train/src/nn/init.cc
@@ -53,6 +53,7 @@ std::shared_ptr Normal(const std::shared_ptr &tensor, float mean
     impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
                       device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
                       impl->GetStream(device));
+    impl->SynchronizeStream(impl->GetStream(device));
 
     return tensor;
 }
@@ -146,6 +147,7 @@ std::shared_ptr Uniform(const std::shared_ptr &tensor, float a,
     impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
                       device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
                       impl->GetStream(device));
+    impl->SynchronizeStream(impl->GetStream(device));
 
     return tensor;
 }
@@ -164,6 +166,7 @@ std::shared_ptr Ones(const std::shared_ptr &tensor) {
     impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
                       device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
                       impl->GetStream(device));
+    impl->SynchronizeStream(impl->GetStream(device));
 
     return tensor;
 }
@@ -182,6 +185,7 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) {
     impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(float),
                       device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
                       impl->GetStream(device));
+    impl->SynchronizeStream(impl->GetStream(device));
 
     return tensor;
 }
@@ -191,6 +195,7 @@ std::shared_ptr Zeros(const std::shared_ptr &tensor) {
         std::vector buffer(num_elements);                                                        \
         std::iota(buffer.begin(), buffer.end(), static_cast(start));                             \
         impl->MemcpyAsync(tensor->DataPtr(), buffer.data(), num_elements * sizeof(TYPE), kind, stream); \
+        impl->SynchronizeStream(stream);                                                          \
         break;                                                                                    \
     }
diff --git a/infini_train/src/tensor.cc b/infini_train/src/tensor.cc
index 774e36ad..06c6354e 100644
--- a/infini_train/src/tensor.cc
+++ b/infini_train/src/tensor.cc
@@ -70,6 +70,7 @@ Tensor::Tensor(const float *data, const std::vector &dims, DataType dty
     impl->MemcpyAsync(buffer_->DataPtr(), data, buffer_->Size(),
                       device.type() == Device::DeviceType::kCPU ? core::MemcpyKind::kD2D : core::MemcpyKind::kH2D,
                       impl->GetStream(device));
+    impl->SynchronizeStream(impl->GetStream(device));
 }
 
 void Tensor::SetData(const Tensor &tensor, size_t offset, bool preserve_data) {
@@ -164,6 +165,7 @@ Tensor Tensor::To(Device device) {
         auto impl = core::GetDeviceGuardImpl(buffer_device.type());
         impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kD2H,
                           impl->GetStream(buffer_device));
+        impl->SynchronizeStream(impl->GetStream(buffer_device));
     } else if (buffer_device.type() == Device::DeviceType::kCPU) {
         new_tensor = Tensor(dims_, dtype_, device);
 
@@ -172,6 +174,7 @@ Tensor Tensor::To(Device device) {
         auto *impl = core::GetDeviceGuardImpl(device.type());
         impl->MemcpyAsync(new_tensor.DataPtr(), DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D,
                           impl->GetStream(device));
+        impl->SynchronizeStream(impl->GetStream(device));
     } else {
         new_tensor = Tensor(dims_, dtype_, device);
         // P2P
@@ -182,6 +185,7 @@ Tensor Tensor::To(Device device) {
         auto *impl = core::GetDeviceGuardImpl(buffer_device.type());
         impl->MemcpyAsync(new_tensor.DataPtr(), cpu_tensor.DataPtr(), SizeInBytes(), core::MemcpyKind::kH2D,
                           impl->GetStream(buffer_device));
+        impl->SynchronizeStream(impl->GetStream(buffer_device));
     }
 
     if (grad_) {
@@ -231,16 +235,19 @@ void Tensor::CopyFrom(const Tensor &src) {
         core::DeviceGuard guard(dst_dev);
         auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
         impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2D, impl->GetStream(dst_dev));
+        impl->SynchronizeStream(impl->GetStream(dst_dev));
     } else if (dst_dev.type() == Device::DeviceType::kCPU) {
         // D2H
         core::DeviceGuard guard(src_dev);
         auto *impl = core::GetDeviceGuardImpl(src_dev.type());
         impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kD2H, impl->GetStream(src_dev));
+        impl->SynchronizeStream(impl->GetStream(src_dev));
     } else if (src_dev.type() == Device::DeviceType::kCPU) {
         // H2D
         core::DeviceGuard guard(dst_dev);
         auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
         impl->MemcpyAsync(DataPtr(), src.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev));
+        impl->SynchronizeStream(impl->GetStream(dst_dev));
     } else {
         // TODO(dcj): maybe support p2p api later
         // P2P
@@ -251,6 +258,7 @@ void Tensor::CopyFrom(const Tensor &src) {
         core::DeviceGuard guard(dst_dev);
         auto *impl = core::GetDeviceGuardImpl(dst_dev.type());
         impl->MemcpyAsync(DataPtr(), cpu_tensor.DataPtr(), nbytes, core::MemcpyKind::kH2D, impl->GetStream(dst_dev));
+        impl->SynchronizeStream(impl->GetStream(dst_dev));
     }
 }
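
Editor's note (not part of the patch): the NOTE(dcj) comments above all guard against the same hazard, a stack-owned host buffer handed to cudaMemcpyAsync that could go out of scope before the enqueued transfer has consumed it. Below is a minimal, standalone sketch of that pattern and of the fix the patch applies (a trailing cudaStreamSynchronize, or alternatively the blocking cudaMemcpy used in concat.cu and gather.cu). The helper name UploadOffsets and the CHECK macro are illustrative assumptions, not infini_train APIs.

    // Illustrative only: UploadOffsets and CHECK are hypothetical names, not from this repository.
    #include <cuda_runtime.h>

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    #define CHECK(call)                                                             \
        do {                                                                        \
            cudaError_t err_ = (call);                                              \
            if (err_ != cudaSuccess) {                                              \
                std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
                std::exit(EXIT_FAILURE);                                            \
            }                                                                       \
        } while (0)

    // Stage a small host-side metadata array into device memory on `stream`.
    int64_t *UploadOffsets(int num_inputs, cudaStream_t stream) {
        std::vector<int64_t> host_offsets(num_inputs + 1, 0); // pageable, stack-owned buffer

        int64_t *device_offsets = nullptr;
        CHECK(cudaMallocAsync(&device_offsets, host_offsets.size() * sizeof(int64_t), stream));

        // cudaMemcpyAsync only enqueues the transfer; depending on whether the source
        // is pinned, the copy may still be reading host_offsets.data() after it returns.
        CHECK(cudaMemcpyAsync(device_offsets, host_offsets.data(),
                              host_offsets.size() * sizeof(int64_t), cudaMemcpyHostToDevice, stream));

        // Without this, host_offsets is destroyed at the closing brace while the
        // transfer might still be in flight: the use-after-scope window the patch
        // closes. The blocking cudaMemcpy used in concat.cu/gather.cu is the other
        // option, since it returns only after the host buffer has been consumed.
        CHECK(cudaStreamSynchronize(stream));
        return device_offsets;
    }

    int main() {
        cudaStream_t stream;
        CHECK(cudaStreamCreate(&stream));
        int64_t *device_offsets = UploadOffsets(/*num_inputs=*/4, stream);
        CHECK(cudaFreeAsync(device_offsets, stream));
        CHECK(cudaStreamSynchronize(stream));
        CHECK(cudaStreamDestroy(stream));
        return 0;
    }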