diff --git a/backends/vulkan/runtime/utils/VecUtils.h b/backends/vulkan/runtime/utils/VecUtils.h index d84eb54d2b9..f7e674374a3 100644 --- a/backends/vulkan/runtime/utils/VecUtils.h +++ b/backends/vulkan/runtime/utils/VecUtils.h @@ -12,6 +12,8 @@ #include +#include + #include #include #include @@ -465,24 +467,19 @@ inline ivec4 make_whcn_ivec4(const std::vector& arr) { } /* - * Wrapper around std::accumulate that accumulates values of a container of - * integral types into int64_t. Taken from `multiply_integers` in - * + * Computes the product of integral values in a container, accumulating into + * int64_t with overflow checking. Throws on overflow. */ template < typename C, std::enable_if_t::value, int> = 0> inline int64_t multiply_integers(const C& container) { - return std::accumulate( - container.begin(), - container.end(), - static_cast(1), - std::multiplies<>()); + return multiply_integers(container.begin(), container.end()); } /* - * Product of integer elements referred to by iterators; accumulates into the - * int64_t datatype. Taken from `multiply_integers` in + * Computes the product of integral values referred to by iterators, + * accumulating into int64_t with overflow checking. Throws on overflow. */ template < typename Iter, @@ -491,11 +488,13 @@ template < typename std::iterator_traits::value_type>::value, int> = 0> inline int64_t multiply_integers(Iter begin, Iter end) { - // std::accumulate infers return type from `init` type, so if the `init` type - // is not large enough to hold the result, computation can overflow. We use - // `int64_t` here to avoid this. - return std::accumulate( - begin, end, static_cast(1), std::multiplies<>()); + int64_t result = 1; + for (Iter it = begin; it != end; ++it) { + VK_CHECK_COND( + !c10::mul_overflows(result, static_cast(*it), &result), + "Integer overflow in multiply_integers"); + } + return result; } class WorkgroupSize final {