From 3ff534b3e860ec9409bcb2a800d38b07931cae15 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Dec 2025 11:58:00 -0600 Subject: [PATCH 1/4] Add test cases --- src/shape_transform_descriptor.cpp | 1 - test/fuse_reduce.cpp | 53 +++++++++++++++++++++++++++++ test/shape_transform_descriptor.cpp | 16 +++++++++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index aac7170cb4d..2743dcd2a01 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -330,7 +330,6 @@ struct rebase_ambiguity_resolver // Returns the axes mapping that can be used for rebase auto resolve() { - std::vector> subs_to_insert; { axes_map_t axes_map = group_axes(desc->dimensions); diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 8bf01bc7789..e5a39b882a4 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -939,6 +939,59 @@ TEST_CASE(reduce_contiguous_reshape_pointwise) EXPECT(p1.sort() == p2.sort()); } +// @411 = pointwise(@404), [main:pointwise256] -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, 4608000, 14400, 160, 1} +// @414 = unsqueeze[axes={1, 2, 3, 4, 5, 8, 9, 10, 11},steps={}](@408) -> float_type, {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} + +// @415 = fused_reduce[axes={7, 8, 9, 10, 11}](@411,@414), [main:pointwise257:main:reduce_sum37:main:pointwise258_reshape] -> float_type, {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} +// @419 = multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160},out_dyn_dims={}](@418) -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {32, 32, 32, 32, 32, 32, 1, 0, 0, 1, 0, 0} +// @421 = pointwise(@404,@410,@419,@420), [main:pointwise255] -> half_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, 4608000, 14400, 160, 1} + +// ops: squeeze[axes={1, 2, 3, 4, 5, 7, 8, 9, 10}], unsqueeze[axes={1, 2, 3, 4, 5, 7, 8, 10, 11},steps={}], multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160},out_dyn_dims={}] +// desc: {[1:0], [1:1], [1:2], [1:3], [1:4], [1:5], [32:6], [10:7x1], [16:], [1:7x0], [16:8], [160:11x1, 90:10, 1:11x0]} + + +TEST_CASE(reduce_squeeze_unsqueeze_pointwise1) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}; + migraphx::shape s2{migraphx::shape::float_type, {8, 8, 2, 2}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {7, 8, 9, 10, 11}}}), x); + auto squeeze = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), rsum); + auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), squeeze); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), unsqueeze); + auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto xr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); + auto add = add_reduce( + p2, + "main:reduce_sum0_reshape:main:pointwise0", + {xr, y}, + {2, 3}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsumb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), rsum); + return add_pointwise( + p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add")); + }); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + TEST_CASE(reduce_reshape_reduce) { migraphx::shape s1{migraphx::shape::float_type, {2, 32, 4096}}; diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 1afc2f84d01..6544ed62b84 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -1223,4 +1223,20 @@ TEST_CASE(rebase_adjust_axes_many_moved_groups) } } +TEST_CASE(rebase_adjust_squeeze_unsqueeze_broadcast) +{ + auto base_desc = make_simple_descriptor({1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, + make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), + make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), + make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}})); + + { + auto desc = base_desc.rebase({1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); + EXPECT(not desc.empty()); + EXPECT(get_final_lens(desc) == final_lens{1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); + EXPECT(get_all_lens(desc) == all_lens{{1}, {1}, {1}, {1}, {1}, {1}, {32}, {10}, {16}, {1}, {90}, {160}}); + EXPECT(desc.generate() == ops{}); + } +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 7e2452d546d59d147c3f865beab85baae2704311 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Dec 2025 11:58:03 -0600 Subject: [PATCH 2/4] Format --- test/fuse_reduce.cpp | 41 +++++++++++++++++++---------- test/shape_transform_descriptor.cpp | 12 +++++---- 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index e5a39b882a4..862fbc47db3 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -939,16 +939,26 @@ TEST_CASE(reduce_contiguous_reshape_pointwise) EXPECT(p1.sort() == p2.sort()); } -// @411 = pointwise(@404), [main:pointwise256] -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, 4608000, 14400, 160, 1} -// @414 = unsqueeze[axes={1, 2, 3, 4, 5, 8, 9, 10, 11},steps={}](@408) -> float_type, {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} +// @411 = pointwise(@404), [main:pointwise256] -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, +// 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, 4608000, 14400, +// 160, 1} +// @414 = unsqueeze[axes={1, 2, 3, 4, 5, 8, 9, 10, 11},steps={}](@408) -> float_type, {1, 1, 1, 1, +// 1, 1, 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} -// @415 = fused_reduce[axes={7, 8, 9, 10, 11}](@411,@414), [main:pointwise257:main:reduce_sum37:main:pointwise258_reshape] -> float_type, {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} -// @419 = multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160},out_dyn_dims={}](@418) -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {32, 32, 32, 32, 32, 32, 1, 0, 0, 1, 0, 0} -// @421 = pointwise(@404,@410,@419,@420), [main:pointwise255] -> half_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, 4608000, 14400, 160, 1} - -// ops: squeeze[axes={1, 2, 3, 4, 5, 7, 8, 9, 10}], unsqueeze[axes={1, 2, 3, 4, 5, 7, 8, 10, 11},steps={}], multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160},out_dyn_dims={}] -// desc: {[1:0], [1:1], [1:2], [1:3], [1:4], [1:5], [32:6], [10:7x1], [16:], [1:7x0], [16:8], [160:11x1, 90:10, 1:11x0]} +// @415 = fused_reduce[axes={7, 8, 9, 10, 11}](@411,@414), +// [main:pointwise257:main:reduce_sum37:main:pointwise258_reshape] -> float_type, {1, 1, 1, 1, 1, 1, +// 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} +// @419 = multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160},out_dyn_dims={}](@418) +// -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {32, 32, 32, 32, 32, 32, 1, 0, 0, 1, +// 0, 0} +// @421 = pointwise(@404,@410,@419,@420), [main:pointwise255] -> half_type, {1, 1, 1, 1, 1, 1, 32, +// 10, 16, 1, 90, 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, +// 4608000, 14400, 160, 1} +// ops: squeeze[axes={1, 2, 3, 4, 5, 7, 8, 9, 10}], unsqueeze[axes={1, 2, 3, 4, 5, 7, 8, 10, +// 11},steps={}], multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, +// 160},out_dyn_dims={}] desc: {[1:0], [1:1], [1:2], [1:3], [1:4], [1:5], [32:6], [10:7x1], [16:], +// [1:7x0], [16:8], [160:11x1, 90:10, 1:11x0]} TEST_CASE(reduce_squeeze_unsqueeze_pointwise1) { @@ -956,12 +966,15 @@ TEST_CASE(reduce_squeeze_unsqueeze_pointwise1) migraphx::shape s2{migraphx::shape::float_type, {8, 8, 2, 2}}; migraphx::program p1; { - auto* mm = p1.get_main_module(); - auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {7, 8, 9, 10, 11}}}), x); - auto squeeze = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), rsum); - auto unsqueeze = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), squeeze); + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s2); + auto rsum = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {7, 8, 9, 10, 11}}}), x); + auto squeeze = mm->add_instruction( + migraphx::make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), rsum); + auto unsqueeze = mm->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), squeeze); auto rsumb = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), unsqueeze); auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add")); diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 6544ed62b84..1d4cb2865bf 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -1225,16 +1225,18 @@ TEST_CASE(rebase_adjust_axes_many_moved_groups) TEST_CASE(rebase_adjust_squeeze_unsqueeze_broadcast) { - auto base_desc = make_simple_descriptor({1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, - make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), - make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), - make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}})); + auto base_desc = make_simple_descriptor( + {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, + make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), + make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), + make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}})); { auto desc = base_desc.rebase({1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); EXPECT(not desc.empty()); EXPECT(get_final_lens(desc) == final_lens{1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); - EXPECT(get_all_lens(desc) == all_lens{{1}, {1}, {1}, {1}, {1}, {1}, {32}, {10}, {16}, {1}, {90}, {160}}); + EXPECT(get_all_lens(desc) == + all_lens{{1}, {1}, {1}, {1}, {1}, {1}, {32}, {10}, {16}, {1}, {90}, {160}}); EXPECT(desc.generate() == ops{}); } } From 6f01484c8ba4199fbeb5ff6104a30f89ad6148bc Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Dec 2025 18:57:49 -0600 Subject: [PATCH 3/4] Fix tests --- src/fuse_reduce.cpp | 19 +++--- src/include/migraphx/rewrite_reshapes.hpp | 3 + src/shape_transform_descriptor.cpp | 78 +++++++++++++++++++++-- test/fuse_reduce.cpp | 33 ++-------- 4 files changed, 93 insertions(+), 40 deletions(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 8d1a7ff39d6..78bd37acc71 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -174,16 +174,19 @@ static auto any_input(Ms... ms) return match::any_of[match::inputs()](match::any(ms...).bind("input")); } -static bool is_valid_broadcast(const instruction_ref b, const std::vector& reduce_axes) +static bool is_valid_broadcast(const instruction_ref b, std::vector reduce_axes) { - std::vector broadcast_axes; - auto bstrides = b->get_shape().strides(); + const auto& blens = b->get_shape().lens(); + const auto& bstrides = b->get_shape().strides(); + reduce_axes.erase(std::remove_if(reduce_axes.begin(), + reduce_axes.end(), + [&](size_t axis) { return blens.at(axis) == 1; }), + reduce_axes.end()); - for(size_t i = 0; i < bstrides.size(); ++i) - { - if(bstrides.at(i) == 0) - broadcast_axes.push_back(i); - } + std::vector broadcast_axes; + copy_if(range(bstrides.size()), + std::back_inserter(broadcast_axes), + [&](size_t i) { return bstrides.at(i) == 0 and blens.at(i) != 1; }); return broadcast_axes == reduce_axes; } diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index f506addc3a7..263d872bc4c 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -163,6 +163,9 @@ struct rewrite_reshapes if(desc.empty()) return; + if(desc.elements() != elements(dims2)) + return; + auto cdims = desc.common_dims(); auto reshape_input = [&](const auto& ins_to_insert, const auto& gdesc) { return [&](auto input) { diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 2743dcd2a01..1664a5af905 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -339,6 +339,9 @@ struct rebase_ambiguity_resolver if(shortage_axes.empty()) return axes_map; + if(try_trivial_direct_mapping()) + return regroup_axes(); + process_axis_groups(axes_map, subs_to_insert); if(shortage_axes.size() == initial_shortage_count) @@ -351,10 +354,7 @@ struct rebase_ambiguity_resolver sort_hidden_axes_groups(); sort_moved_axes_groups(); - axes_map_t regroup_axes = group_axes(desc->dimensions); - renumber_axes(regroup_axes); - - return regroup_axes; + return regroup_axes(); } private: @@ -368,6 +368,13 @@ struct rebase_ambiguity_resolver return x / y; } + axes_map_t regroup_axes() + { + axes_map_t result = group_axes(desc->dimensions); + renumber_axes(result); + return result; + } + // Identifies axes where the target dimension is larger than current subdimensions // These are "shortage" axes that need subdimensions due to ambiguous axis assignment void find_shortage_axes(const axes_map_t& axes_map) @@ -384,6 +391,69 @@ struct rebase_ambiguity_resolver initial_shortage_count = shortage_axes.size(); } + bool try_trivial_direct_mapping() + { + if(desc->lens() != *dims) + return false; + if(not std::all_of(desc->dimensions.begin(), desc->dimensions.end(), [&](const dimension& d) { + if(d.subdimensions.empty()) + return false; + if(d.len() == 1) + return true; + if(std::any_of(d.subdimensions.begin(), d.subdimensions.end(), [&](const dimension::sub& s) { + if(s.origin_axis().empty()) + return false; + if(s.origin_axis().size() != 1) + return true; + if(s.len == 1) + return false; + if(s.has_hidden_axis()) + return false; + return ((*dims)[s.origin_axis().front()] != s.len); + })) + return false; + if(d.subdimensions.size() == 1) + return true; + auto n1dims = std::count_if(d.subdimensions.begin(), + d.subdimensions.end(), + [](const dimension::sub& s) { return s.len == 1; }); + return n1dims + 1 == d.subdimensions.size(); + })) + return false; + std::vector axes; + for_each_subdimension(desc->dimensions, [&](auto& s) { + if(s.origin_axis().empty()) + return; + axes.push_back(s.origin_axis().front()); + }); + // TODO: Handle permutations + if(not std::is_sorted(axes.begin(), axes.end())) + return false; + for(std::size_t i:range(desc->dimensions.size())) + { + auto& dim = desc->dimensions[i]; + if(dim.subdimensions.empty()) + continue; + auto sub = std::find_if(dim.subdimensions.begin(), dim.subdimensions.end(), [&](const dimension::sub& s) { + return s.len != 1; + }); + if(sub ==dim.subdimensions.end()) + sub = dim.subdimensions.begin(); + sub->expose(); + sub->axis = {i}; + + auto remove_axis = [](dimension::sub& s) { + s.axis.clear(); + s.hidden_axis.clear(); + s.len = 1; + }; + std::for_each(dim.subdimensions.begin(), sub, remove_axis); + std::for_each(std::next(sub), dim.subdimensions.end(), remove_axis); + } + shortage_axes.clear(); + return true; + } + // Processes each axis group to resolve ambiguous axis assignments // This is the core logic that fixes mismatches from reshape ambiguity // diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 862fbc47db3..98f69b6fef5 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -939,36 +939,14 @@ TEST_CASE(reduce_contiguous_reshape_pointwise) EXPECT(p1.sort() == p2.sort()); } -// @411 = pointwise(@404), [main:pointwise256] -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, -// 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, 4608000, 14400, -// 160, 1} -// @414 = unsqueeze[axes={1, 2, 3, 4, 5, 8, 9, 10, 11},steps={}](@408) -> float_type, {1, 1, 1, 1, -// 1, 1, 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} - -// @415 = fused_reduce[axes={7, 8, 9, 10, 11}](@411,@414), -// [main:pointwise257:main:reduce_sum37:main:pointwise258_reshape] -> float_type, {1, 1, 1, 1, 1, 1, -// 32, 1, 1, 1, 1, 1}, {32, 32, 32, 32, 32, 32, 1, 1, 1, 1, 1, 1} -// @419 = multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160},out_dyn_dims={}](@418) -// -> float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}, {32, 32, 32, 32, 32, 32, 1, 0, 0, 1, -// 0, 0} -// @421 = pointwise(@404,@410,@419,@420), [main:pointwise255] -> half_type, {1, 1, 1, 1, 1, 1, 32, -// 10, 16, 1, 90, 160}, {73728000, 73728000, 73728000, 73728000, 73728000, 73728000, 144000, 14400, -// 4608000, 14400, 160, 1} - -// ops: squeeze[axes={1, 2, 3, 4, 5, 7, 8, 9, 10}], unsqueeze[axes={1, 2, 3, 4, 5, 7, 8, 10, -// 11},steps={}], multibroadcast[out_lens={1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, -// 160},out_dyn_dims={}] desc: {[1:0], [1:1], [1:2], [1:3], [1:4], [1:5], [32:6], [10:7x1], [16:], -// [1:7x0], [16:8], [160:11x1, 90:10, 1:11x0]} - TEST_CASE(reduce_squeeze_unsqueeze_pointwise1) { migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}; - migraphx::shape s2{migraphx::shape::float_type, {8, 8, 2, 2}}; migraphx::program p1; { auto* mm = p1.get_main_module(); auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); + auto y = mm->add_parameter("y", s1); auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {7, 8, 9, 10, 11}}}), x); auto squeeze = mm->add_instruction( @@ -985,18 +963,17 @@ TEST_CASE(reduce_squeeze_unsqueeze_pointwise1) { auto* mm = p2.get_main_module(); auto x = mm->add_parameter("x", s1); - auto y = mm->add_parameter("y", s2); - auto xr = mm->add_instruction(migraphx::make_op("reshape", {{"dims", s2.lens()}}), x); + auto y = mm->add_parameter("y", s1); auto add = add_reduce( p2, "main:reduce_sum0_reshape:main:pointwise0", - {xr, y}, - {2, 3}, + {x, y}, + {7, 8, 9, 10, 11}, [&](auto* rm, const auto& inputs, const auto& axes) { auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), inputs[0]); auto rsumb = rm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s2.lens()}}), rsum); + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); return add_pointwise( p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add")); }); From a6c2a8d2dd08f48f4b3c212119d56354a61fcfb0 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 4 Dec 2025 18:57:52 -0600 Subject: [PATCH 4/4] Format --- src/fuse_reduce.cpp | 8 ++-- src/shape_transform_descriptor.cpp | 61 ++++++++++++++++-------------- 2 files changed, 36 insertions(+), 33 deletions(-) diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 78bd37acc71..f3147b9cc1d 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -176,7 +176,7 @@ static auto any_input(Ms... ms) static bool is_valid_broadcast(const instruction_ref b, std::vector reduce_axes) { - const auto& blens = b->get_shape().lens(); + const auto& blens = b->get_shape().lens(); const auto& bstrides = b->get_shape().strides(); reduce_axes.erase(std::remove_if(reduce_axes.begin(), reduce_axes.end(), @@ -184,9 +184,9 @@ static bool is_valid_broadcast(const instruction_ref b, std::vector redu reduce_axes.end()); std::vector broadcast_axes; - copy_if(range(bstrides.size()), - std::back_inserter(broadcast_axes), - [&](size_t i) { return bstrides.at(i) == 0 and blens.at(i) != 1; }); + copy_if(range(bstrides.size()), std::back_inserter(broadcast_axes), [&](size_t i) { + return bstrides.at(i) == 0 and blens.at(i) != 1; + }); return broadcast_axes == reduce_axes; } diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index 1664a5af905..55e86fa49ee 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -395,30 +395,33 @@ struct rebase_ambiguity_resolver { if(desc->lens() != *dims) return false; - if(not std::all_of(desc->dimensions.begin(), desc->dimensions.end(), [&](const dimension& d) { - if(d.subdimensions.empty()) - return false; - if(d.len() == 1) - return true; - if(std::any_of(d.subdimensions.begin(), d.subdimensions.end(), [&](const dimension::sub& s) { - if(s.origin_axis().empty()) - return false; - if(s.origin_axis().size() != 1) - return true; - if(s.len == 1) - return false; - if(s.has_hidden_axis()) - return false; - return ((*dims)[s.origin_axis().front()] != s.len); - })) - return false; - if(d.subdimensions.size() == 1) - return true; - auto n1dims = std::count_if(d.subdimensions.begin(), - d.subdimensions.end(), - [](const dimension::sub& s) { return s.len == 1; }); - return n1dims + 1 == d.subdimensions.size(); - })) + if(not std::all_of( + desc->dimensions.begin(), desc->dimensions.end(), [&](const dimension& d) { + if(d.subdimensions.empty()) + return false; + if(d.len() == 1) + return true; + if(std::any_of(d.subdimensions.begin(), + d.subdimensions.end(), + [&](const dimension::sub& s) { + if(s.origin_axis().empty()) + return false; + if(s.origin_axis().size() != 1) + return true; + if(s.len == 1) + return false; + if(s.has_hidden_axis()) + return false; + return ((*dims)[s.origin_axis().front()] != s.len); + })) + return false; + if(d.subdimensions.size() == 1) + return true; + auto n1dims = std::count_if(d.subdimensions.begin(), + d.subdimensions.end(), + [](const dimension::sub& s) { return s.len == 1; }); + return n1dims + 1 == d.subdimensions.size(); + })) return false; std::vector axes; for_each_subdimension(desc->dimensions, [&](auto& s) { @@ -429,15 +432,15 @@ struct rebase_ambiguity_resolver // TODO: Handle permutations if(not std::is_sorted(axes.begin(), axes.end())) return false; - for(std::size_t i:range(desc->dimensions.size())) + for(std::size_t i : range(desc->dimensions.size())) { auto& dim = desc->dimensions[i]; if(dim.subdimensions.empty()) continue; - auto sub = std::find_if(dim.subdimensions.begin(), dim.subdimensions.end(), [&](const dimension::sub& s) { - return s.len != 1; - }); - if(sub ==dim.subdimensions.end()) + auto sub = std::find_if(dim.subdimensions.begin(), + dim.subdimensions.end(), + [&](const dimension::sub& s) { return s.len != 1; }); + if(sub == dim.subdimensions.end()) sub = dim.subdimensions.begin(); sub->expose(); sub->axis = {i};