// Review notes for this patch (the text below this preamble is unchanged).
//
// Layout: this is a unified git diff whose hunks have been collapsed onto a few very long
// lines. NOTE(review): template argument lists look stripped (e.g. `std::vector reduce_axes`,
// `std::vector> subs_to_insert`) - confirm against the upstream patch before applying.
//
// src/fuse_reduce.cpp - is_valid_broadcast:
//   * `reduce_axes` is now taken by value so it can be filtered locally.
//   * Reduce axes whose broadcast length (`blens.at(axis)`) is 1 are erased before comparing.
//   * `broadcast_axes` collects only indices with `bstrides.at(i) == 0` and `blens.at(i) != 1`.
//   * The result is still `broadcast_axes == reduce_axes`.
//   NOTE(review): `blens.at(axis)` is a new access; presumably every reduce axis is within the
//   rank of `b` - verify against the caller.
//
// src/include/migraphx/rewrite_reshapes.hpp:
//   * Adds an early return when `desc.elements() != elements(dims2)`, placed right after the
//     existing `desc.empty()` early return.
//
// src/shape_transform_descriptor.cpp - rebase_ambiguity_resolver:
//   * resolve(): after the `shortage_axes.empty()` early return it now calls
//     try_trivial_direct_mapping() and, on success, returns regroup_axes().
//   * regroup_axes(): new private helper holding the former tail of resolve()
//     (`group_axes(desc->dimensions)` followed by `renumber_axes`).
//   * try_trivial_direct_mapping(): new helper.
//       - Returns false unless `desc->lens() == *dims`.
//       - Returns false unless every dimension passes the `std::all_of` predicate.
//       - Gathers the front origin axis of each subdimension that has one and returns false
//         when that list is not sorted (permutations are left as a TODO).
//       - Otherwise, per dimension, it picks the first subdimension with `len != 1` (or the
//         first one), calls `expose()` on it and sets its `axis` to `{i}`.
//       - Every other subdimension gets `axis`/`hidden_axis` cleared and `len = 1`.
//       - Finally clears `shortage_axes` and returns true.
//     NOTE(review): `n1dims + 1 == d.subdimensions.size()` compares the signed result of
//     `std::count_if` with an unsigned size; consider an explicit cast.
//     NOTE(review): the `dim.subdimensions.empty()` guard in the final loop looks unreachable,
//     since the `std::all_of` predicate already rejects empty subdimensions.
//
// Tests added:
//   * test/fuse_reduce.cpp: TEST_CASE(reduce_squeeze_unsqueeze_pointwise1).
//     NOTE(review): it checks `p1 == p2`, while the preceding test compares
//     `p1.sort() == p2.sort()` - confirm this is intended.
//   * test/shape_transform_descriptor.cpp: TEST_CASE(rebase_adjust_squeeze_unsqueeze_broadcast).
//
diff --git a/src/fuse_reduce.cpp b/src/fuse_reduce.cpp index 8d1a7ff39d6..f3147b9cc1d 100644 --- a/src/fuse_reduce.cpp +++ b/src/fuse_reduce.cpp @@ -174,16 +174,19 @@ static auto any_input(Ms... ms) return match::any_of[match::inputs()](match::any(ms...).bind("input")); } -static bool is_valid_broadcast(const instruction_ref b, const std::vector& reduce_axes) +static bool is_valid_broadcast(const instruction_ref b, std::vector reduce_axes) { - std::vector broadcast_axes; - auto bstrides = b->get_shape().strides(); + const auto& blens = b->get_shape().lens(); + const auto& bstrides = b->get_shape().strides(); + reduce_axes.erase(std::remove_if(reduce_axes.begin(), + reduce_axes.end(), + [&](size_t axis) { return blens.at(axis) == 1; }), + reduce_axes.end()); - for(size_t i = 0; i < bstrides.size(); ++i) - { - if(bstrides.at(i) == 0) - broadcast_axes.push_back(i); - } + std::vector broadcast_axes; + copy_if(range(bstrides.size()), std::back_inserter(broadcast_axes), [&](size_t i) { + return bstrides.at(i) == 0 and blens.at(i) != 1; + }); return broadcast_axes == reduce_axes; } diff --git a/src/include/migraphx/rewrite_reshapes.hpp b/src/include/migraphx/rewrite_reshapes.hpp index f506addc3a7..263d872bc4c 100644 --- a/src/include/migraphx/rewrite_reshapes.hpp +++ b/src/include/migraphx/rewrite_reshapes.hpp @@ -163,6 +163,9 @@ struct rewrite_reshapes if(desc.empty()) return; + if(desc.elements() != elements(dims2)) + return; + auto cdims = desc.common_dims(); auto reshape_input = [&](const auto& ins_to_insert, const auto& gdesc) { return [&](auto input) { diff --git a/src/shape_transform_descriptor.cpp b/src/shape_transform_descriptor.cpp index aac7170cb4d..55e86fa49ee 100644 --- a/src/shape_transform_descriptor.cpp +++ b/src/shape_transform_descriptor.cpp @@ -330,7 +330,6 @@ struct rebase_ambiguity_resolver // Returns the axes mapping that can be used for rebase auto resolve() { - std::vector> subs_to_insert; { axes_map_t axes_map = group_axes(desc->dimensions); @@ 
-340,6 +339,9 @@ struct rebase_ambiguity_resolver if(shortage_axes.empty()) return axes_map; + if(try_trivial_direct_mapping()) + return regroup_axes(); + process_axis_groups(axes_map, subs_to_insert); if(shortage_axes.size() == initial_shortage_count) @@ -352,10 +354,7 @@ struct rebase_ambiguity_resolver sort_hidden_axes_groups(); sort_moved_axes_groups(); - axes_map_t regroup_axes = group_axes(desc->dimensions); - renumber_axes(regroup_axes); - - return regroup_axes; + return regroup_axes(); } private: @@ -369,6 +368,13 @@ struct rebase_ambiguity_resolver return x / y; } + axes_map_t regroup_axes() + { + axes_map_t result = group_axes(desc->dimensions); + renumber_axes(result); + return result; + } + // Identifies axes where the target dimension is larger than current subdimensions // These are "shortage" axes that need subdimensions due to ambiguous axis assignment void find_shortage_axes(const axes_map_t& axes_map) @@ -385,6 +391,72 @@ struct rebase_ambiguity_resolver initial_shortage_count = shortage_axes.size(); } + bool try_trivial_direct_mapping() + { + if(desc->lens() != *dims) + return false; + if(not std::all_of( + desc->dimensions.begin(), desc->dimensions.end(), [&](const dimension& d) { + if(d.subdimensions.empty()) + return false; + if(d.len() == 1) + return true; + if(std::any_of(d.subdimensions.begin(), + d.subdimensions.end(), + [&](const dimension::sub& s) { + if(s.origin_axis().empty()) + return false; + if(s.origin_axis().size() != 1) + return true; + if(s.len == 1) + return false; + if(s.has_hidden_axis()) + return false; + return ((*dims)[s.origin_axis().front()] != s.len); + })) + return false; + if(d.subdimensions.size() == 1) + return true; + auto n1dims = std::count_if(d.subdimensions.begin(), + d.subdimensions.end(), + [](const dimension::sub& s) { return s.len == 1; }); + return n1dims + 1 == d.subdimensions.size(); + })) + return false; + std::vector axes; + for_each_subdimension(desc->dimensions, [&](auto& s) { + 
if(s.origin_axis().empty()) + return; + axes.push_back(s.origin_axis().front()); + }); + // TODO: Handle permutations + if(not std::is_sorted(axes.begin(), axes.end())) + return false; + for(std::size_t i : range(desc->dimensions.size())) + { + auto& dim = desc->dimensions[i]; + if(dim.subdimensions.empty()) + continue; + auto sub = std::find_if(dim.subdimensions.begin(), + dim.subdimensions.end(), + [&](const dimension::sub& s) { return s.len != 1; }); + if(sub == dim.subdimensions.end()) + sub = dim.subdimensions.begin(); + sub->expose(); + sub->axis = {i}; + + auto remove_axis = [](dimension::sub& s) { + s.axis.clear(); + s.hidden_axis.clear(); + s.len = 1; + }; + std::for_each(dim.subdimensions.begin(), sub, remove_axis); + std::for_each(std::next(sub), dim.subdimensions.end(), remove_axis); + } + shortage_axes.clear(); + return true; + } + // Processes each axis group to resolve ambiguous axis assignments // This is the core logic that fixes mismatches from reshape ambiguity // diff --git a/test/fuse_reduce.cpp b/test/fuse_reduce.cpp index 8bf01bc7789..98f69b6fef5 100644 --- a/test/fuse_reduce.cpp +++ b/test/fuse_reduce.cpp @@ -939,6 +939,49 @@ TEST_CASE(reduce_contiguous_reshape_pointwise) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(reduce_squeeze_unsqueeze_pointwise1) +{ + migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}; + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto rsum = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {7, 8, 9, 10, 11}}}), x); + auto squeeze = mm->add_instruction( + migraphx::make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), rsum); + auto unsqueeze = mm->add_instruction( + migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), squeeze); + auto rsumb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), unsqueeze); 
+ auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add")); + mm->add_return({add}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", s1); + auto y = mm->add_parameter("y", s1); + auto add = add_reduce( + p2, + "main:reduce_sum0_reshape:main:pointwise0", + {x, y}, + {7, 8, 9, 10, 11}, + [&](auto* rm, const auto& inputs, const auto& axes) { + auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), + inputs[0]); + auto rsumb = rm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); + return add_pointwise( + p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add")); + }); + mm->add_return({add}); + } + EXPECT(p1 == p2); +} + TEST_CASE(reduce_reshape_reduce) { migraphx::shape s1{migraphx::shape::float_type, {2, 32, 4096}}; diff --git a/test/shape_transform_descriptor.cpp b/test/shape_transform_descriptor.cpp index 1afc2f84d01..1d4cb2865bf 100644 --- a/test/shape_transform_descriptor.cpp +++ b/test/shape_transform_descriptor.cpp @@ -1223,4 +1223,22 @@ TEST_CASE(rebase_adjust_axes_many_moved_groups) } } +TEST_CASE(rebase_adjust_squeeze_unsqueeze_broadcast) +{ + auto base_desc = make_simple_descriptor( + {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}, + make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), + make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), + make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}})); + + { + auto desc = base_desc.rebase({1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); + EXPECT(not desc.empty()); + EXPECT(get_final_lens(desc) == final_lens{1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}); + EXPECT(get_all_lens(desc) == + all_lens{{1}, {1}, {1}, {1}, {1}, {1}, {32}, {10}, {16}, {1}, {90}, {160}}); + EXPECT(desc.generate() == ops{}); + } +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }