diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index d48e99126a..01753f9ea9 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -511,7 +511,7 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL IdxT* count, MappingOpT mapping_op) { - IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); + IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.x); if (l >= n_clusters) return; auto csize = static_cast(cluster_sizes[l]); // skip big clusters @@ -616,7 +616,7 @@ auto adjust_centers(MathT* centers, constexpr uint32_t kBlockDimY = 4; const dim3 block_dim(raft::WarpSize, kBlockDimY, 1); - const dim3 grid_dim(1, raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1); + const dim3 grid_dim(raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1, 1); rmm::device_scalar update_count(0, stream, device_memory); adjust_centers_kernel<<>>(centers, n_clusters,