From 19da758cba1a7b6f0d3cd8f3618fb7010f3d221f Mon Sep 17 00:00:00 2001
From: Nischal1729
Date: Mon, 15 Sep 2025 10:49:14 +0530
Subject: [PATCH] fix: for balanced kmeans use grid.x for adjust_centers to avoid grid.y overflow for large n_clusters

---
 cpp/src/cluster/detail/kmeans_balanced.cuh | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh
index d48e99126a..01753f9ea9 100644
--- a/cpp/src/cluster/detail/kmeans_balanced.cuh
+++ b/cpp/src/cluster/detail/kmeans_balanced.cuh
@@ -511,7 +511,7 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL
                         IdxT* count,
                         MappingOpT mapping_op)
 {
-  IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y);
+  IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.x);
   if (l >= n_clusters) return;
   auto csize = static_cast(cluster_sizes[l]);
   // skip big clusters
@@ -616,7 +616,7 @@ auto adjust_centers(MathT* centers,
 
   constexpr uint32_t kBlockDimY = 4;
   const dim3 block_dim(raft::WarpSize, kBlockDimY, 1);
-  const dim3 grid_dim(1, raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1);
+  const dim3 grid_dim(raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1, 1);
   rmm::device_scalar update_count(0, stream, device_memory);
   adjust_centers_kernel<<>>(centers,
                             n_clusters,