Skip to content

Commit 1deb127

Browse files
committed
feat: Add SVE kernels for TopKV
Change-Id: I7a0c7bd1154b9cb7f35c7fd1c3b8ad54698f8799 Signed-off-by: Pablo Marquez Tello <pablo.tello@arm.com>
1 parent 9c9151c commit 1deb127

11 files changed

Lines changed: 554 additions & 20 deletions

File tree

filelist.json

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2453,27 +2453,31 @@
24532453
]
24542454
}
24552455
},
2456+
24562457
"TopKV": {
2457-
"files": {
2458-
"common": [
2459-
"src/cpu/kernels/CpuTopKVKernel.cpp",
2460-
"src/cpu/operators/CpuTopKV.cpp",
2461-
"src/runtime/NEON/functions/NETopKV.cpp"
2462-
],
2463-
"neon": {
2464-
"fp16": [ "src/cpu/kernels/topkv/generic/neon/fp16.cpp" ],
2465-
"fp32": [ "src/cpu/kernels/topkv/generic/neon/fp32.cpp" ],
2466-
"integer":["src/cpu/kernels/topkv/generic/neon/integer.cpp"],
2467-
"qasymm8": [
2468-
"src/cpu/kernels/topkv/generic/neon/qasymm8.cpp"
2469-
],
2470-
"qasymm8_signed": [
2471-
"src/cpu/kernels/topkv/generic/neon/qasymm8_signed.cpp"
2472-
]
2473-
}
2458+
"files": {
2459+
"common": [
2460+
"src/cpu/kernels/CpuTopKVKernel.cpp",
2461+
"src/cpu/operators/CpuTopKV.cpp",
2462+
"src/runtime/NEON/functions/NETopKV.cpp"
2463+
],
2464+
"neon": {
2465+
"fp16": [ "src/cpu/kernels/topkv/generic/neon/fp16.cpp" ],
2466+
"fp32": [ "src/cpu/kernels/topkv/generic/neon/fp32.cpp" ],
2467+
"integer": [ "src/cpu/kernels/topkv/generic/neon/integer.cpp" ],
2468+
"qasymm8": [ "src/cpu/kernels/topkv/generic/neon/qasymm8.cpp" ],
2469+
"qasymm8_signed": [ "src/cpu/kernels/topkv/generic/neon/qasymm8_signed.cpp" ]
2470+
},
2471+
"sve": {
2472+
"fp32": [ "src/cpu/kernels/topkv/generic/sve/fp32.cpp" ],
2473+
"fp16": [ "src/cpu/kernels/topkv/generic/sve/fp16.cpp" ],
2474+
"integer": [ "src/cpu/kernels/topkv/generic/sve/integer.cpp" ],
2475+
"qasymm8": [ "src/cpu/kernels/topkv/generic/sve/qasymm8.cpp" ],
2476+
"qasymm8_signed": [ "src/cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp" ]
2477+
}
2478+
}
2479+
},
24742480

2475-
}
2476-
},
24772481
"Transpose": {
24782482
"files": {
24792483
"common": [

src/BUILD.bazel

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,12 @@ filegroup(
395395
"cpu/kernels/scale/sve/qasymm8.cpp",
396396
"cpu/kernels/scale/sve/qasymm8_signed.cpp",
397397
"cpu/kernels/softmax/generic/sve/impl.cpp",
398-
"cpu/kernels/softmax/generic/sve/impl_bf16.cpp"] +
398+
"cpu/kernels/softmax/generic/sve/impl_bf16.cpp",
399+
"cpu/kernels/topkv/generic/sve/fp16.cpp",
400+
"cpu/kernels/topkv/generic/sve/fp32.cpp",
401+
"cpu/kernels/topkv/generic/sve/integer.cpp",
402+
"cpu/kernels/topkv/generic/sve/qasymm8.cpp",
403+
"cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp"] +
399404
glob(["**/*.h",
400405
"**/*.hpp",
401406
"**/*.inl"]),

src/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,11 @@ target_sources(
365365
cpu/kernels/scale/sve/qasymm8_signed.cpp
366366
cpu/kernels/softmax/generic/sve/impl.cpp
367367
cpu/kernels/softmax/generic/sve/impl_bf16.cpp
368+
cpu/kernels/topkv/generic/sve/fp16.cpp
369+
cpu/kernels/topkv/generic/sve/fp32.cpp
370+
cpu/kernels/topkv/generic/sve/integer.cpp
371+
cpu/kernels/topkv/generic/sve/qasymm8.cpp
372+
cpu/kernels/topkv/generic/sve/qasymm8_signed.cpp
368373
)
369374

370375
target_sources(

src/cpu/kernels/CpuTopKVKernel.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,42 @@ namespace
4343
{
4444

4545
static const std::vector<CpuTopKVKernel::TopKVKernel> available_kernels = {
46+
47+
{"sve_fp16_topkv",
48+
[](const CpuTopKVKernelDataTypeISASelectorData &data)
49+
{ return (data.dt == DataType::F16) && data.isa.fp16 && data.isa.sve; },
50+
REGISTER_FP16_SVE(arm_compute::cpu::topkv_fp16_sve)},
51+
52+
{"sve_fp32_topkv",
53+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F32) && data.isa.sve; },
54+
REGISTER_FP32_SVE(arm_compute::cpu::topkv_fp32_sve)},
55+
56+
{"sve_qasymm8_topkv",
57+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8) && data.isa.sve; },
58+
REGISTER_QASYMM8_SVE(arm_compute::cpu::topkv_qasymm8_sve)},
59+
60+
{"sve_qasymm8_signed_topkv",
61+
[](const CpuTopKVKernelDataTypeISASelectorData &data)
62+
{ return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve; },
63+
REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::topkv_qasymm8_signed_sve)},
64+
65+
{"sve_s32_topkv",
66+
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::S32) && data.isa.sve; },
67+
REGISTER_INTEGER_SVE(arm_compute::cpu::topkv_s32_sve)},
68+
4669
{"neon_s32_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::S32); },
4770
REGISTER_INTEGER_NEON(arm_compute::cpu::topkv_s32_neon)},
71+
4872
{"neon_fp32_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F32); },
4973
REGISTER_FP32_NEON(arm_compute::cpu::topkv_fp32_neon)},
74+
5075
{"neon_fp16_topkv",
5176
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::F16) && data.isa.fp16; },
5277
REGISTER_FP16_NEON(arm_compute::cpu::topkv_fp16_neon)},
78+
5379
{"neon_qu8_topkv", [](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8); },
5480
REGISTER_QASYMM8_NEON(arm_compute::cpu::topkv_qasymm8_neon)},
81+
5582
{"neon_qs8_topkv",
5683
[](const CpuTopKVKernelDataTypeISASelectorData &data) { return (data.dt == DataType::QASYMM8_SIGNED); },
5784
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::topkv_qasymm8_signed_neon)}};
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#if defined(__ARM_FEATURE_SVE)
25+
26+
#include "src/cpu/kernels/topkv/generic/sve/impl.h"
27+
28+
#include <arm_sve.h>
29+
30+
namespace arm_compute
31+
{
32+
namespace cpu
33+
{
34+
namespace detail
35+
{
36+
37+
template <>
38+
inline uint32_t block_width<float16_t>(uint32_t remaining)
39+
{
40+
const uint32_t bw = static_cast<uint32_t>(svcnth());
41+
return (remaining < bw) ? remaining : bw;
42+
}
43+
44+
template <>
45+
inline uint32_t count_gt_block<float16_t>(const float16_t *ptr, float16_t thr, uint32_t remaining)
46+
{
47+
const uint32_t bw = block_width<float16_t>(remaining);
48+
svbool_t pg = svwhilelt_b16((uint64_t)0, (uint64_t)bw);
49+
svfloat16_t v = svld1_f16(pg, ptr);
50+
svbool_t gt = svcmpgt_n_f16(pg, v, thr);
51+
return static_cast<uint32_t>(svcntp_b16(svptrue_b16(), gt));
52+
}
53+
54+
} // namespace detail
55+
56+
void topkv_fp16_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)
57+
{
58+
detail::topkv_sve_wrapper<float16_t>(predictions, targets, out, k, win);
59+
}
60+
61+
template void
62+
detail::topkv_sve_wrapper<float16_t>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &);
63+
64+
} // namespace cpu
65+
} // namespace arm_compute
66+
67+
#endif // defined(__ARM_FEATURE_SVE)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#if defined(__ARM_FEATURE_SVE)
25+
26+
#include "src/cpu/kernels/topkv/generic/sve/impl.h"
27+
28+
#include <arm_sve.h>
29+
#include <cstdint>
30+
31+
namespace arm_compute
32+
{
33+
namespace cpu
34+
{
35+
namespace detail
36+
{
37+
38+
// float32 block width (32-bit lanes)
39+
template <>
40+
inline uint32_t block_width<float>(uint32_t remaining)
41+
{
42+
const uint32_t bw = static_cast<uint32_t>(svcntw());
43+
return (remaining < bw) ? remaining : bw;
44+
}
45+
46+
// Count lanes where v > thr for a block of float32
47+
template <>
48+
inline uint32_t count_gt_block<float>(const float *ptr, float thr, uint32_t remaining)
49+
{
50+
const uint32_t bw = block_width<float>(remaining);
51+
52+
const svbool_t pg = svwhilelt_b32(static_cast<uint64_t>(0), static_cast<uint64_t>(bw));
53+
const svfloat32_t v = svld1_f32(pg, ptr);
54+
const svbool_t gt = svcmpgt_n_f32(pg, v, thr);
55+
56+
return static_cast<uint32_t>(svcntp_b32(svptrue_b32(), gt));
57+
}
58+
59+
} // namespace detail
60+
61+
// Exported SVE kernel for FP32 TopKV
62+
void topkv_fp32_sve(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &win)
63+
{
64+
detail::topkv_sve_wrapper<float>(predictions, targets, out, k, win);
65+
}
66+
67+
// Explicit instantiation of the generic wrapper
68+
template void detail::topkv_sve_wrapper<float>(const ITensor *, const ITensor *, ITensor *, uint32_t, const Window &);
69+
70+
} // namespace cpu
71+
} // namespace arm_compute
72+
73+
#endif // __ARM_FEATURE_SVE
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Copyright (c) 2026 Arm Limited.
3+
*
4+
* SPDX-License-Identifier: MIT
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to
8+
* deal in the Software without restriction, including without limitation the
9+
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10+
* sell copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in all
14+
* copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
* SOFTWARE.
23+
*/
24+
#ifndef ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H
25+
#define ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H
26+
27+
#include "arm_compute/core/Coordinates.h"
28+
#include "arm_compute/core/Error.h"
29+
#include "arm_compute/core/Helpers.h"
30+
#include "arm_compute/core/ITensor.h"
31+
#include "arm_compute/core/Window.h"
32+
33+
#include <cstdint>
34+
#include <cstring>
35+
36+
namespace arm_compute
37+
{
38+
namespace cpu
39+
{
40+
namespace detail
41+
{
42+
43+
/*
44+
* Type-specific hooks
45+
*
46+
* - count_gt_block<Scalar>(ptr, thr, remaining)
47+
* Count how many elements in the next vector-block are > thr.
48+
* It should process at most 'remaining' elements and return the count.
49+
*
50+
* - block_width<Scalar>(remaining)
51+
* Return how many elements the specialization processes in one block
52+
* (usually the full vector width clamped to 'remaining').
53+
*
54+
* Both must be implemented in the cpp that contains the SVE intrinsics
55+
* (e.g., qasymm8.cpp, qasymm8_signed.cpp, fp16.cpp, fp32.cpp, integer.cpp).
56+
*/
57+
58+
// Count lanes > thr in a single block (type-specific, implemented in .cpp)
59+
template <typename Scalar>
60+
uint32_t count_gt_block(const Scalar *ptr, Scalar thr, uint32_t remaining);
61+
62+
// Number of elements processed by one block for this Scalar (type-specific)
63+
template <typename Scalar>
64+
uint32_t block_width(uint32_t remaining);
65+
66+
// ----------------------------------------------------------------------------
67+
// Generic wrapper (type-agnostic) - uses the above hooks.
68+
// Iteration semantics:
69+
// - predictions is N x C
70+
// - window iterates across output elements (classes) => id.x() == class index c
71+
// - for each class c, targets[c] gives the sample index t
72+
// - scan across N samples and compute rank (#samples with value > predictions[t])
73+
// ----------------------------------------------------------------------------
74+
template <typename Scalar>
75+
inline void
76+
topkv_sve_wrapper(const ITensor *predictions, const ITensor *targets, ITensor *out, uint32_t k, const Window &window)
77+
{
78+
ARM_COMPUTE_ERROR_ON_NULLPTR(predictions, targets, out);
79+
ARM_COMPUTE_ERROR_ON(k == 0);
80+
81+
const ITensorInfo *pred_info = predictions->info();
82+
const uint32_t N = pred_info->dimension(0); // samples
83+
const uint32_t C = pred_info->dimension(1); // classes
84+
85+
Iterator tgt_it(targets, window);
86+
Iterator out_it(out, window);
87+
88+
execute_window_loop(
89+
window,
90+
[&](const Coordinates &id)
91+
{
92+
const uint32_t c = static_cast<uint32_t>(id.x()); // class index
93+
ARM_COMPUTE_ERROR_ON(c >= C);
94+
95+
uint32_t t = 0;
96+
std::memcpy(&t, tgt_it.ptr(), sizeof(uint32_t)); // target sample idx for this class
97+
ARM_COMPUTE_ERROR_ON(t >= N);
98+
99+
// column pointer: for class c, samples are along dim0 starting at (0,c)
100+
const Scalar *col_ptr = reinterpret_cast<const Scalar *>(predictions->ptr_to_element(Coordinates(0, c)));
101+
ARM_COMPUTE_ERROR_ON(col_ptr == nullptr);
102+
103+
const Scalar thr = col_ptr[t];
104+
105+
uint32_t rank = 0;
106+
uint32_t idx = 0;
107+
while (idx < N)
108+
{
109+
const uint32_t remaining = N - idx;
110+
const uint32_t bw = block_width<Scalar>(remaining);
111+
// count_gt_block is expected to look only within [ptr, ptr + bw)
112+
rank += count_gt_block<Scalar>(col_ptr + idx, thr, remaining);
113+
114+
if (rank >= k)
115+
{
116+
break;
117+
}
118+
119+
idx += bw;
120+
}
121+
122+
*reinterpret_cast<uint8_t *>(out_it.ptr()) = static_cast<uint8_t>(rank < k);
123+
},
124+
tgt_it, out_it);
125+
}
126+
127+
} // namespace detail
128+
} // namespace cpu
129+
} // namespace arm_compute
130+
131+
#endif // ACL_SRC_CPU_KERNELS_TOPKV_GENERIC_SVE_IMPL_H

0 commit comments

Comments
 (0)