diff --git a/runtime/core/portable_type/tensor_impl.cpp b/runtime/core/portable_type/tensor_impl.cpp index 17243fca0fd..df670f9b32f 100644 --- a/runtime/core/portable_type/tensor_impl.cpp +++ b/runtime/core/portable_type/tensor_impl.cpp @@ -10,6 +10,7 @@ #include #include +#include #include @@ -68,6 +69,25 @@ TensorImpl::TensorImpl( ET_CHECK_MSG(dim_ >= 0, "Dimension must be non-negative, got %zd", dim_); } +Result TensorImpl::safe_numel() const { + ssize_t numel = 1; + for (const auto i : c10::irange(dim_)) { + ET_CHECK_OR_RETURN_ERROR( + sizes_[i] >= 0, + InvalidArgument, + "Size must be non-negative, got %zd at dimension %zd", + static_cast(sizes_[i]), + i); + ET_CHECK_OR_RETURN_ERROR( + sizes_[i] == 0 || + numel <= std::numeric_limits::max() / sizes_[i], + InvalidArgument, + "Tensor numel overflows ssize_t"); + numel *= sizes_[i]; + } + return numel; +} + size_t TensorImpl::nbytes() const { return numel_ * elementSize(type_); } diff --git a/runtime/core/portable_type/tensor_impl.h b/runtime/core/portable_type/tensor_impl.h index ea2cde5aeb0..fed08baa3b9 100644 --- a/runtime/core/portable_type/tensor_impl.h +++ b/runtime/core/portable_type/tensor_impl.h @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -149,6 +150,12 @@ class TensorImpl { return numel_; } + /** + * Returns the number of elements in the tensor, or an error if the result + * would overflow ssize_t. + */ + Result safe_numel() const; + /// Returns the type of the elements in the tensor (int32, float, bool, etc). ScalarType scalar_type() const { return type_; diff --git a/runtime/core/portable_type/test/tensor_impl_test.cpp b/runtime/core/portable_type/test/tensor_impl_test.cpp index 7d045da5b3d..c3dcd615347 100644 --- a/runtime/core/portable_type/test/tensor_impl_test.cpp +++ b/runtime/core/portable_type/test/tensor_impl_test.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -40,6 +41,33 @@ class TensorImplTest : public ::testing::Test { } }; +TEST_F(TensorImplTest, SafeNumelReturnsCorrectValue) { + SizesType sizes[2] = {3, 2}; + TensorImpl t(ScalarType::Float, 2, sizes); + auto result = t.safe_numel(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(*result, 6); +} + +TEST_F(TensorImplTest, SafeNumelScalar) { + TensorImpl t(ScalarType::Float, 0, nullptr); + auto result = t.safe_numel(); + ASSERT_TRUE(result.ok()); + EXPECT_EQ(*result, 1); +} + +TEST_F(TensorImplTest, SafeNumelOverflowReturnsError) { + // Three large dimensions whose product overflows ssize_t on any platform: + // On 64-bit: INT32_MAX^2 * 3 > INT64_MAX; on 32-bit: INT32_MAX^2 > INT32_MAX. + SizesType sizes[3] = { + std::numeric_limits::max(), + std::numeric_limits::max(), + 3}; + TensorImpl t(ScalarType::Float, 3, sizes); + auto result = t.safe_numel(); + EXPECT_FALSE(result.ok()); +} + TEST_F(TensorImplTest, TestCtorAndGetters) { SizesType sizes[2] = {3, 2}; DimOrderType dim_order[2] = {0, 1};