@@ -81,13 +81,16 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
8181 input_strides_dev = steps_dev + steps.size ();
8282 output_strides_dev = input_strides_dev + dims.size ();
8383
84- CUDA_CHECK (cudaMemcpy (new_dims_dev, new_dims.data (), ends.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
85- CUDA_CHECK (cudaMemcpy (starts_dev, starts.data (), starts.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
86- CUDA_CHECK (cudaMemcpy (steps_dev, steps.data (), steps.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
8784 CUDA_CHECK (
88- cudaMemcpy (input_strides_dev, src_strides .data (), dims .size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
85+ cudaMemcpyAsync (new_dims_dev, new_dims .data (), ends .size () * sizeof (int64_t ), cudaMemcpyHostToDevice, stream ));
8986 CUDA_CHECK (
90- cudaMemcpy (output_strides_dev, dst_strides.data (), new_dims.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
87+ cudaMemcpyAsync (starts_dev, starts.data (), starts.size () * sizeof (int64_t ), cudaMemcpyHostToDevice, stream));
88+ CUDA_CHECK (
89+ cudaMemcpyAsync (steps_dev, steps.data (), steps.size () * sizeof (int64_t ), cudaMemcpyHostToDevice, stream));
90+ CUDA_CHECK (cudaMemcpyAsync (input_strides_dev, src_strides.data (), dims.size () * sizeof (int64_t ),
91+ cudaMemcpyHostToDevice, stream));
92+ CUDA_CHECK (cudaMemcpyAsync (output_strides_dev, dst_strides.data (), new_dims.size () * sizeof (int64_t ),
93+ cudaMemcpyHostToDevice, stream));
9194
9295 int threads_per_block = 256 ;
9396 int num_blocks = (total_elements + threads_per_block - 1 ) / threads_per_block;
@@ -103,6 +106,11 @@ std::shared_ptr<Tensor> SliceForward(const std::shared_ptr<Tensor> &input, const
103106
104107 cudaFreeAsync (new_dims_dev, stream);
105108
109+ // NOTE(dcj):
110+ // Synchronize the stream here to ensure all preceding async H2D memcpy
111+ // operations (whose sources are pageable host buffers from std::vector)
112+ // have completed before those host buffers go out of scope.
112+ CUDA_CHECK (cudaStreamSynchronize (stream));
113+
106114 return new_tensor;
107115}
108116
@@ -175,13 +183,16 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
175183 input_strides_dev = steps_dev + steps.size ();
176184 output_strides_dev = input_strides_dev + dims.size ();
177185
178- CUDA_CHECK (cudaMemcpy (new_dims_dev, new_dims.data (), ends.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
179- CUDA_CHECK (cudaMemcpy (starts_dev, starts.data (), starts.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
180- CUDA_CHECK (cudaMemcpy (steps_dev, steps.data (), steps.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
181186 CUDA_CHECK (
182- cudaMemcpy (input_strides_dev, src_strides .data (), dims .size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
187+ cudaMemcpyAsync (new_dims_dev, new_dims .data (), ends .size () * sizeof (int64_t ), cudaMemcpyHostToDevice, stream ));
183188 CUDA_CHECK (
184- cudaMemcpy (output_strides_dev, dst_strides.data (), new_dims.size () * sizeof (int64_t ), cudaMemcpyHostToDevice));
189+ cudaMemcpyAsync (starts_dev, starts.data (), starts.size () * sizeof (int64_t ), cudaMemcpyHostToDevice, stream));
190+ CUDA_CHECK (
191+ cudaMemcpyAsync (steps_dev, steps.data (), steps.size () * sizeof (int64_t ), cudaMemcpyHostToDevice, stream));
192+ CUDA_CHECK (cudaMemcpyAsync (input_strides_dev, src_strides.data (), dims.size () * sizeof (int64_t ),
193+ cudaMemcpyHostToDevice, stream));
194+ CUDA_CHECK (cudaMemcpyAsync (output_strides_dev, dst_strides.data (), new_dims.size () * sizeof (int64_t ),
195+ cudaMemcpyHostToDevice, stream));
185196
186197 int threads_per_block = 256 ;
187198 int num_blocks = (total_elements + threads_per_block - 1 ) / threads_per_block;
@@ -197,6 +208,11 @@ std::shared_ptr<Tensor> SliceBackward(const std::shared_ptr<Tensor> &grad_output
197208
198209 CUDA_CHECK (cudaFreeAsync (new_dims_dev, stream));
199210
211+ // NOTE(dcj):
212+ // Synchronize the stream here to ensure all preceding async H2D memcpy
213+ // operations (whose sources are pageable host buffers from std::vector)
214+ // have completed before those host buffers go out of scope.
214+ CUDA_CHECK (cudaStreamSynchronize (stream));
215+
200216 return grad_input;
201217}
202218} // namespace infini_train::kernels::cuda
0 commit comments