-
Notifications
You must be signed in to change notification settings - Fork 922
Fix #18654: Add safe_numel() #18698
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix #18654: Add safe_numel() #18698
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |||||||||||
|
|
||||||||||||
| #include <algorithm> | ||||||||||||
| #include <cstdint> | ||||||||||||
| #include <limits> | ||||||||||||
|
|
||||||||||||
| #include <c10/util/irange.h> | ||||||||||||
|
|
||||||||||||
|
|
@@ -68,6 +69,25 @@ TensorImpl::TensorImpl( | |||||||||||
| ET_CHECK_MSG(dim_ >= 0, "Dimension must be non-negative, got %zd", dim_); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| Result<ssize_t> TensorImpl::safe_numel() const { | ||||||||||||
| ssize_t numel = 1; | ||||||||||||
| for (const auto i : c10::irange(dim_)) { | ||||||||||||
| ET_CHECK_OR_RETURN_ERROR( | ||||||||||||
| sizes_[i] >= 0, | ||||||||||||
| InvalidArgument, | ||||||||||||
| "Size must be non-negative, got %zd at dimension %zd", | ||||||||||||
| static_cast<ssize_t>(sizes_[i]), | ||||||||||||
| i); | ||||||||||||
| ET_CHECK_OR_RETURN_ERROR( | ||||||||||||
| sizes_[i] == 0 || | ||||||||||||
| numel <= std::numeric_limits<ssize_t>::max() / sizes_[i], | ||||||||||||
| InvalidArgument, | ||||||||||||
| "Tensor numel overflows ssize_t"); | ||||||||||||
|
||||||||||||
| "Tensor numel overflows ssize_t"); | |
| "Tensor numel overflows ssize_t at dimension %zd: size=%zd, partial_numel=%zd", | |
| i, | |
| static_cast<ssize_t>(sizes_[i]), | |
| numel); |
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |||||||||||
|
|
||||||||||||
| #include <executorch/runtime/core/array_ref.h> | ||||||||||||
| #include <executorch/runtime/core/error.h> | ||||||||||||
| #include <executorch/runtime/core/result.h> | ||||||||||||
| #include <executorch/runtime/core/portable_type/device.h> | ||||||||||||
| #include <executorch/runtime/core/portable_type/scalar_type.h> | ||||||||||||
| #include <executorch/runtime/core/tensor_shape_dynamism.h> | ||||||||||||
|
|
@@ -149,6 +150,12 @@ class TensorImpl { | |||||||||||
| return numel_; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| /** | ||||||||||||
| * Returns the number of elements in the tensor, or an error if the result | ||||||||||||
| * would overflow ssize_t. | ||||||||||||
|
Comment on lines
+154
to
+155
|
||||||||||||
| * Returns the number of elements in the tensor, or an error if the result | |
| * would overflow ssize_t. | |
| * Returns the number of elements in the tensor, or an error if any | |
| * dimension size is negative or if computing the result would overflow | |
| * ssize_t. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| #include <executorch/runtime/core/portable_type/tensor_impl.h> | ||
|
|
||
| #include <gtest/gtest.h> | ||
| #include <limits> | ||
| #include <random> | ||
|
|
||
| #include <executorch/runtime/core/exec_aten/util/tensor_util.h> | ||
|
|
@@ -40,6 +41,33 @@ class TensorImplTest : public ::testing::Test { | |
| } | ||
| }; | ||
|
|
||
| TEST_F(TensorImplTest, SafeNumelReturnsCorrectValue) { | ||
| SizesType sizes[2] = {3, 2}; | ||
| TensorImpl t(ScalarType::Float, 2, sizes); | ||
| auto result = t.safe_numel(); | ||
| ASSERT_TRUE(result.ok()); | ||
| EXPECT_EQ(*result, 6); | ||
| } | ||
|
|
||
| TEST_F(TensorImplTest, SafeNumelScalar) { | ||
| TensorImpl t(ScalarType::Float, 0, nullptr); | ||
| auto result = t.safe_numel(); | ||
| ASSERT_TRUE(result.ok()); | ||
| EXPECT_EQ(*result, 1); | ||
| } | ||
|
|
||
| TEST_F(TensorImplTest, SafeNumelOverflowReturnsError) { | ||
| // Three large dimensions whose product overflows ssize_t on any platform: | ||
| // On 64-bit: INT32_MAX^2 * 3 > INT64_MAX; on 32-bit: INT32_MAX^2 > INT32_MAX. | ||
| SizesType sizes[3] = { | ||
| std::numeric_limits<SizesType>::max(), | ||
| std::numeric_limits<SizesType>::max(), | ||
| 3}; | ||
|
Comment on lines
+60
to
+65
|
||
| TensorImpl t(ScalarType::Float, 3, sizes); | ||
| auto result = t.safe_numel(); | ||
| EXPECT_FALSE(result.ok()); | ||
| } | ||
|
|
||
| TEST_F(TensorImplTest, TestCtorAndGetters) { | ||
| SizesType sizes[2] = {3, 2}; | ||
| DimOrderType dim_order[2] = {0, 1}; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
safe_numel() duplicates much of compute_numel()’s logic (iteration + non-negative size checks). Consider factoring the shared computation into a single helper (e.g., a checked-multiply routine) so future changes to size validation/numel semantics don’t have to be kept in sync in multiple places.