fix: correct negative axis handling in roll function#2878
Open
f14XuanLv wants to merge 1 commit intoxtensor-stack:masterfrom
Open
fix: correct negative axis handling in roll function#2878f14XuanLv wants to merge 1 commit intoxtensor-stack:masterfrom
f14XuanLv wants to merge 1 commit intoxtensor-stack:masterfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Checklist
Description
This PR fixes a bug in
xt::roll(e, shift, axis)where negative axis indices (e.g.,-1for the last axis) were incorrectly rejected.The Bug
include/xtensor/misc/xmanipulation.hppThe original code converted
axistosize_tbefore normalization:This caused valid negative indices like
-1to incorrectly trigger the bounds check exception.The Fix
include/xtensor/misc/xmanipulation.hppauto cpy = empty_like(e); const auto& shape = cpy. shape(); - std::size_t saxis = static_cast<std::size_t>(axis); - if (axis < 0) - { - axis += std::ptrdiff_t(cpy. dimension()); - } + const auto dim = cpy.dimension(); - if (saxis >= cpy.dimension() || axis < 0) + if (axis < -static_cast<std::ptrdiff_t>(dim) || axis >= static_cast<std::ptrdiff_t>(dim)) { - XTENSOR_THROW(std::runtime_error, "axis is no within shape dimension."); + XTENSOR_THROW(std::runtime_error, "axis is not within shape dimension."); } + std::size_t saxis = normalize_axis(dim, axis); + const auto axis_dim = static_cast<std::ptrdiff_t>(shape[saxis]);normalize_axis()for consistency with other functions (swapaxes,moveaxis, etc.)"axis is no within"→"axis is not within"Tests Added
test/test_xmanipulation.cppxarray<double> expected8 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}}; ASSERT_EQ(expected8, xt::roll(e2, -2, /*axis*/ 2)); + // Boundary error cases + EXPECT_THROW(xt::roll(e2, 1, /*axis*/ 3), std::runtime_error); + EXPECT_THROW(xt::roll(e2, 1, /*axis*/ -4), std::runtime_error); + + // Negative axis indices + xarray<double> expected9 = {{{3, 1, 2}}, {{6, 4, 5}}, {{9, 7, 8}}}; + ASSERT_EQ(expected9, xt::roll(e2, -2, /*axis*/ -1)); + + xarray<double> expected10 = {{{1, 2, 3}}, {{4, 5, 6}}, {{7, 8, 9}}}; + ASSERT_EQ(expected10, xt::roll(e2, -2, /*axis*/ -2)); + + xarray<double> expected11 = {{{4, 5, 6}}, {{7, 8, 9}}, {{1, 2, 3}}}; + ASSERT_EQ(expected11, xt::roll(e2, 2, /*axis*/ -3)); }Note: This bug has existed since #1823 (2019).