19 changes: 11 additions & 8 deletions src/fuse_reduce.cpp
@@ -174,16 +174,19 @@ static auto any_input(Ms... ms)
     return match::any_of[match::inputs()](match::any(ms...).bind("input"));
 }
 
-static bool is_valid_broadcast(const instruction_ref b, const std::vector<size_t>& reduce_axes)
+static bool is_valid_broadcast(const instruction_ref b, std::vector<size_t> reduce_axes)
 {
-    std::vector<size_t> broadcast_axes;
-    auto bstrides = b->get_shape().strides();
+    const auto& blens = b->get_shape().lens();
+    const auto& bstrides = b->get_shape().strides();
+    reduce_axes.erase(std::remove_if(reduce_axes.begin(),
+                                     reduce_axes.end(),
+                                     [&](size_t axis) { return blens.at(axis) == 1; }),
+                      reduce_axes.end());
 
-    for(size_t i = 0; i < bstrides.size(); ++i)
-    {
-        if(bstrides.at(i) == 0)
-            broadcast_axes.push_back(i);
-    }
+    std::vector<size_t> broadcast_axes;
+    copy_if(range(bstrides.size()), std::back_inserter(broadcast_axes), [&](size_t i) {
+        return bstrides.at(i) == 0 and blens.at(i) != 1;
+    });
 
     return broadcast_axes == reduce_axes;
 }
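A length-1 axis carries no data, whether broadcast or reduced, which is why the new code drops extent-1 axes from both sides of the comparison. Below is a minimal standalone sketch of the updated predicate, using plain std::vector stand-ins for the MIGraphX shape API; the stand-in signature and the example shapes are this sketch's assumptions:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Standalone sketch of the updated is_valid_broadcast logic; hypothetical
// plain-vector stand-ins for instruction_ref/shape, not the MIGraphX API.
static bool is_valid_broadcast_sketch(std::vector<std::size_t> lens,
                                      std::vector<std::size_t> strides,
                                      std::vector<std::size_t> reduce_axes)
{
    // Ignore reduce axes of extent 1; they carry no data to reduce.
    reduce_axes.erase(std::remove_if(reduce_axes.begin(),
                                     reduce_axes.end(),
                                     [&](std::size_t axis) { return lens.at(axis) == 1; }),
                      reduce_axes.end());

    // A broadcast axis has stride 0; skip extent-1 axes here as well so the
    // two sides stay comparable.
    std::vector<std::size_t> broadcast_axes;
    for(std::size_t i = 0; i < strides.size(); ++i)
    {
        if(strides.at(i) == 0 and lens.at(i) != 1)
            broadcast_axes.push_back(i);
    }
    return broadcast_axes == reduce_axes;
}

int main()
{
    // {2, 1, 4} broadcast along axes 0 and 2; axis 1 has extent 1, so a
    // reduction over {0, 1, 2} should still be accepted.
    assert(is_valid_broadcast_sketch({2, 1, 4}, {0, 1, 0}, {0, 1, 2}));
    // Reducing only axis 0 does not cover the broadcast on axis 2.
    assert(not is_valid_broadcast_sketch({2, 1, 4}, {0, 1, 0}, {0}));
}

With lens {2, 1, 4} and strides {0, 1, 0}, axes 0 and 2 are broadcast axes; reducing over {0, 1, 2} is accepted because the extent-1 axis 1 is dropped, while reducing over {0} alone fails to cover the broadcast on axis 2.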
3 changes: 3 additions & 0 deletions src/include/migraphx/rewrite_reshapes.hpp
@@ -163,6 +163,9 @@ struct rewrite_reshapes
         if(desc.empty())
             return;
 
+        if(desc.elements() != elements(dims2))
CharlieL7 (Collaborator) commented on Dec 5, 2025:

For the Llama3.2 issue, it looks like this line just prevents the rewrite. The gather before fused_reduce is {1, 1, 2048}, {2048, 2048, 1}, while the output of the fused_reduce + reshapes is {1, 1, 64, 1, 32}, {1, 1, 0, 1, 0}. To move the reshape instructions from after the fused_reduce to before it, we would have to unsqueeze and broadcast the gather to something like {1, 1, 64, 2048, 32}, {0, 0, 0, 1, 0}.

The shape_transform_descriptor after rebase with the bugged code is {[batch_size: 0], [1: 1], [64,:], [2048:2], [32:]}. You mentioned that this was incorrect, but it does look right to me, since the 64 and 32 dimensions are broadcasted dimensions.

pfultz2 (Collaborator, Author) replied:

> For the Llama3.2 issue, it looks like this line just prevents the rewrite.

Yes it will. Once we have the logger, I would like to log a message for these cases because it means we are missing some perf issues.

> The gather before fused_reduce is {1, 1, 2048}, {2048, 2048, 1}, while the output of the fused_reduce + reshapes is {1, 1, 64, 1, 32}, {1, 1, 0, 1, 0}.

What are the reshape ops being used? You can print the ops vector. I think we might need that to reproduce this issue.

> You mentioned that this was incorrect, but it does look right to me, since the 64 and 32 dimensions are broadcasted dimensions.

No, it's not right, because we aren't broadcasting on the input to the pointwise.

We start with {1, 1, 2048}; after the reduce it's {1, 1, 1}, then it's reshaped (or unsqueezed) to {1, 1, 1, 1, 1}, which is then broadcasted to {1, 1, 64, 1, 32}.

We start with {1, 1, 1} arriving at {1, 1, 64, 1, 32} with the shape transform descriptor, and then we rebase it with the {1, 1, 2048} so it starts with that instead of {1, 1, 1}, because we want to move the transformations before the reduce so that we can fuse them together.

So we want {1, 1, 2048} reshaped to {1, 1, 64, 1, 32} and the reduction to happen along the last 3 axes, but the descriptor is showing a reshape to {1, 1, 1, 2048, 1} followed by a broadcast to {1, 1, 64, 2048, 32}, which is wrong.

CharlieL7 (Collaborator) replied on Dec 5, 2025:

> So we want {1, 1, 2048} reshaped to {1, 1, 64, 1, 32} and the reduction to happen along the last 3 axes, but the descriptor is showing a reshape to {1, 1, 1, 2048, 1} followed by a broadcast to {1, 1, 64, 2048, 32}, which is wrong.

{1, 1, 64, 1, 32} would not work as the instruction before the reduce, though? The reduction has to occur on the 2048, so with {1, 1, 64, 1, 32} the reduction output shape would be {1, 1, 1, 1, 1}, which would have to be broadcasted again. I don't see what the benefit would be.

CharlieL7 (Collaborator) replied:

Based on the above, what we want is this?

TEST_CASE(rebase_broadcasted_scalar)
{
    // Taken from bug found when compiling Llama3.2
    auto base_desc =
        make_simple_descriptor({1, 1, 1},
                               make_op("unsqueeze", {{"axes", {2, 4}}}),
                               make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}));

    {
        auto desc = base_desc.rebase({1, 1, 2048});
        EXPECT(not desc.empty());
        EXPECT(get_final_lens(desc) == final_lens{1, 1, 64, 1, 32});
        EXPECT(get_all_lens(desc) == all_lens{{1}, {1}, {64}, {1}, {32}});
        EXPECT(get_all_axes(desc) == all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2, 0}}, d_axes{{2, 1}}, d_axes{{2, 2}}});
        auto generated = desc.generate();
        EXPECT(generated ==
               ops{
                   make_op("reshape", {{"out_lens", {1, 1, 64, 1, 32}}}),
               });
    }
}

pfultz2 (Collaborator, Author) replied on Dec 5, 2025:

> Based on the above, what we want is this?

Yeah.

pfultz2 (Collaborator, Author) added:

> {1, 1, 64, 1, 32} would not work as the instruction before the reduce, though? The reduction has to occur on the 2048, so with {1, 1, 64, 1, 32} the reduction output shape would be {1, 1, 1, 1, 1}, which would have to be broadcasted again. I don't see what the benefit would be.

Currently there is a reduce -> unsqueeze -> broadcast -> pointwise, which we don't fuse because of the unsqueeze. After the rewrite we will have reshape -> reduce -> broadcast -> pointwise, and we can then fuse the reduce -> broadcast -> pointwise.

+            return;
+
         auto cdims = desc.common_dims();
         auto reshape_input = [&](const auto& ins_to_insert, const auto& gdesc) {
             return [&](auto input) {
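To make the before/after concrete, here is a hedged sketch of the two graphs described in the thread above, written with the same program-building API as this PR's tests; the shapes come from the discussion, while the op attributes, names, and exact reshape placement are this sketch's assumptions rather than the pass's actual output:

// Sketch only: mirrors the test-style API used elsewhere in this PR.
#include <migraphx/make_op.hpp>
#include <migraphx/program.hpp>

int main()
{
    migraphx::shape xs{migraphx::shape::float_type, {1, 1, 2048}};
    migraphx::shape ys{migraphx::shape::float_type, {1, 1, 64, 1, 32}};

    migraphx::program before;
    {
        auto* mm = before.get_main_module();
        auto x   = mm->add_parameter("x", xs);
        auto y   = mm->add_parameter("y", ys);
        // reduce -> unsqueeze -> broadcast -> pointwise: the unsqueeze
        // between the reduce and the broadcast blocks the fusion.
        auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), x);
        auto unsq = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), rsum);
        auto bcast = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}), unsq);
        auto sum = mm->add_instruction(migraphx::make_op("add"), bcast, y);
        mm->add_return({sum});
    }

    migraphx::program after;
    {
        auto* mm = after.get_main_module();
        auto x   = mm->add_parameter("x", xs);
        auto y   = mm->add_parameter("y", ys);
        // reshape -> reduce -> broadcast -> pointwise: the layout change now
        // happens before the reduce, so the tail is fusable.
        auto rsp = mm->add_instruction(
            migraphx::make_op("reshape", {{"dims", {1, 1, 64, 1, 32}}}), x);
        auto rsum = mm->add_instruction(
            migraphx::make_op("reduce_sum", {{"axes", {2, 3, 4}}}), rsp);
        auto bcast = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {1, 1, 64, 1, 32}}}), rsum);
        auto sum = mm->add_instruction(migraphx::make_op("add"), bcast, y);
        mm->add_return({sum});
    }
    return 0;
}

In the first program only the pointwise tail can fuse; in the second, a reduce -> broadcast -> pointwise chain remains, which fuse_reduce can handle.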
82 changes: 77 additions & 5 deletions src/shape_transform_descriptor.cpp
@@ -330,7 +330,6 @@ struct rebase_ambiguity_resolver
     // Returns the axes mapping that can be used for rebase
     auto resolve()
     {
-
         std::vector<std::pair<dimension::sub, std::size_t>> subs_to_insert;
         {
             axes_map_t axes_map = group_axes(desc->dimensions);
@@ -340,6 +339,9 @@
             if(shortage_axes.empty())
                 return axes_map;
 
+            if(try_trivial_direct_mapping())
+                return regroup_axes();
+
             process_axis_groups(axes_map, subs_to_insert);
 
             if(shortage_axes.size() == initial_shortage_count)
@@ -352,10 +354,7 @@
         sort_hidden_axes_groups();
         sort_moved_axes_groups();
 
-        axes_map_t regroup_axes = group_axes(desc->dimensions);
-        renumber_axes(regroup_axes);
-
-        return regroup_axes;
+        return regroup_axes();
     }
 
 private:
@@ -369,6 +368,13 @@
         return x / y;
     }
 
+    axes_map_t regroup_axes()
+    {
+        axes_map_t result = group_axes(desc->dimensions);
+        renumber_axes(result);
+        return result;
+    }
+
     // Identifies axes where the target dimension is larger than current subdimensions
     // These are "shortage" axes that need subdimensions due to ambiguous axis assignment
     void find_shortage_axes(const axes_map_t& axes_map)
@@ -385,6 +391,72 @@
         initial_shortage_count = shortage_axes.size();
     }
 
+    bool try_trivial_direct_mapping()
+    {
+        if(desc->lens() != *dims)
+            return false;
+        if(not std::all_of(
+               desc->dimensions.begin(), desc->dimensions.end(), [&](const dimension& d) {
+                   if(d.subdimensions.empty())
+                       return false;
+                   if(d.len() == 1)
+                       return true;
+                   if(std::any_of(d.subdimensions.begin(),
+                                  d.subdimensions.end(),
+                                  [&](const dimension::sub& s) {
+                                      if(s.origin_axis().empty())
+                                          return false;
+                                      if(s.origin_axis().size() != 1)
+                                          return true;
+                                      if(s.len == 1)
+                                          return false;
+                                      if(s.has_hidden_axis())
+                                          return false;
+                                      return ((*dims)[s.origin_axis().front()] != s.len);
+                                  }))
+                       return false;
+                   if(d.subdimensions.size() == 1)
+                       return true;
+                   auto n1dims = std::count_if(d.subdimensions.begin(),
+                                               d.subdimensions.end(),
+                                               [](const dimension::sub& s) { return s.len == 1; });
+                   return n1dims + 1 == d.subdimensions.size();
+               }))
+            return false;
+        std::vector<std::size_t> axes;
+        for_each_subdimension(desc->dimensions, [&](auto& s) {
+            if(s.origin_axis().empty())
+                return;
+            axes.push_back(s.origin_axis().front());
+        });
+        // TODO: Handle permutations
+        if(not std::is_sorted(axes.begin(), axes.end()))
+            return false;
+        for(std::size_t i : range(desc->dimensions.size()))
+        {
+            auto& dim = desc->dimensions[i];
+            if(dim.subdimensions.empty())
+                continue;
+            auto sub = std::find_if(dim.subdimensions.begin(),
+                                    dim.subdimensions.end(),
+                                    [&](const dimension::sub& s) { return s.len != 1; });
+            if(sub == dim.subdimensions.end())
+                sub = dim.subdimensions.begin();
+            sub->expose();
+            sub->axis = {i};
+
+            auto remove_axis = [](dimension::sub& s) {
+                s.axis.clear();
+                s.hidden_axis.clear();
+                s.len = 1;
+            };
+            std::for_each(dim.subdimensions.begin(), sub, remove_axis);
+            std::for_each(std::next(sub), dim.subdimensions.end(), remove_axis);
+        }
+        shortage_axes.clear();
+        return true;
+    }
+
     // Processes each axis group to resolve ambiguous axis assignments
     // This is the core logic that fixes mismatches from reshape ambiguity
     //
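The new fast path applies when the descriptor's lens already match the rebase dims and every dimension has at most one subdimension of extent greater than one: that subdimension is pointed directly at axis i and the remaining subdimensions collapse to extent-1 placeholders. A minimal sketch of just that mapping rule, with simplified stand-in types (the real dimension/sub structs also track origin and hidden axes, and the real check additionally verifies lens equality and rejects permuted axes):

#include <cstddef>
#include <vector>

// Simplified stand-ins for the descriptor's dimension/subdimension types.
struct sub_dim
{
    std::size_t len = 1;
    std::vector<std::size_t> axis; // origin axis, when known
};
struct dim
{
    std::vector<sub_dim> subdimensions;
};

// Sketch of the trivial direct mapping: pick each dimension's single
// non-trivial subdimension (or the first, if all have extent 1), point it at
// axis i, and collapse the rest to extent-1 placeholders.
static void trivial_direct_mapping(std::vector<dim>& dimensions)
{
    for(std::size_t i = 0; i < dimensions.size(); ++i)
    {
        auto& subs = dimensions[i].subdimensions;
        if(subs.empty())
            continue;
        auto* chosen = &subs.front();
        for(auto& s : subs)
        {
            if(s.len != 1)
            {
                chosen = &s;
                break;
            }
        }
        chosen->axis = {i};
        for(auto& s : subs)
        {
            if(&s == chosen)
                continue;
            s.axis.clear();
            s.len = 1;
        }
    }
}

int main()
{
    std::vector<dim> dims(2);
    dims[0].subdimensions = {{1, {}}, {64, {}}, {1, {}}};
    dims[1].subdimensions = {{32, {}}};
    trivial_direct_mapping(dims);
    // dims[0]'s 64-subdimension now maps to axis 0; dims[1]'s 32 to axis 1.
    return dims[0].subdimensions[1].axis == std::vector<std::size_t>{0} ? 0 : 1;
}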
43 changes: 43 additions & 0 deletions test/fuse_reduce.cpp
@@ -939,6 +939,49 @@ TEST_CASE(reduce_contiguous_reshape_pointwise)
     EXPECT(p1.sort() == p2.sort());
 }
 
+TEST_CASE(reduce_squeeze_unsqueeze_pointwise1)
+{
+    migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}};
+    migraphx::program p1;
+    {
+        auto* mm = p1.get_main_module();
+        auto x   = mm->add_parameter("x", s1);
+        auto y   = mm->add_parameter("y", s1);
+        auto rsum =
+            mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {7, 8, 9, 10, 11}}}), x);
+        auto squeeze = mm->add_instruction(
+            migraphx::make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}), rsum);
+        auto unsqueeze = mm->add_instruction(
+            migraphx::make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}), squeeze);
+        auto rsumb = mm->add_instruction(
+            migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), unsqueeze);
+        auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add"));
+        mm->add_return({add});
+    }
+    run_pass(p1);
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        auto x   = mm->add_parameter("x", s1);
+        auto y   = mm->add_parameter("y", s1);
+        auto add = add_reduce(
+            p2,
+            "main:reduce_sum0_reshape:main:pointwise0",
+            {x, y},
+            {7, 8, 9, 10, 11},
+            [&](auto* rm, const auto& inputs, const auto& axes) {
+                auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
+                                                inputs[0]);
+                auto rsumb = rm->add_instruction(
+                    migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum);
+                return add_pointwise(
+                    p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add"));
+            });
+        mm->add_return({add});
+    }
+    EXPECT(p1 == p2);
+}
+
 TEST_CASE(reduce_reshape_reduce)
 {
     migraphx::shape s1{migraphx::shape::float_type, {2, 32, 4096}};
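The axis arithmetic in the reduce_squeeze_unsqueeze_pointwise1 test above is easier to follow as a trace. Below is a small standalone sketch (plain lens vectors, not the MIGraphX shape API) replaying the squeeze/unsqueeze round trip to show it is shape-neutral, which is why the pass can replace squeeze + unsqueeze + multibroadcast with a single multibroadcast:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using lens = std::vector<std::size_t>;

// Drop the given axes (each must have extent 1).
static lens squeeze(lens l, std::vector<std::size_t> axes)
{
    std::sort(axes.rbegin(), axes.rend()); // remove from highest index down
    for(auto a : axes)
    {
        assert(l.at(a) == 1);
        l.erase(l.begin() + a);
    }
    return l;
}

// Insert extent-1 axes at the given positions (positions refer to the output).
static lens unsqueeze(lens l, const std::vector<std::size_t>& axes)
{
    for(auto a : axes) // axes are given in increasing order
        l.insert(l.begin() + a, 1);
    return l;
}

int main()
{
    // Output of reduce_sum over axes {7, 8, 9, 10, 11} in the test above.
    lens reduced{1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1};
    auto sq = squeeze(reduced, {1, 2, 3, 4, 5, 7, 8, 9, 10}); // -> {1, 32, 1}
    auto un = unsqueeze(sq, {1, 2, 3, 4, 5, 7, 8, 10, 11});   // -> 12-d again
    assert((un == lens{1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1}));
    // The round trip is shape-neutral, so only the multibroadcast matters.
    return 0;
}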
18 changes: 18 additions & 0 deletions test/shape_transform_descriptor.cpp
@@ -1223,4 +1223,22 @@ TEST_CASE(rebase_adjust_axes_many_moved_groups)
     }
 }
 
+TEST_CASE(rebase_adjust_squeeze_unsqueeze_broadcast)
+{
+    auto base_desc = make_simple_descriptor(
+        {1, 1, 1, 1, 1, 1, 32, 1, 1, 1, 1, 1},
+        make_op("squeeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 9, 10}}}),
+        make_op("unsqueeze", {{"axes", {1, 2, 3, 4, 5, 7, 8, 10, 11}}}),
+        make_op("multibroadcast", {{"out_lens", {1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160}}}));
+
+    {
+        auto desc = base_desc.rebase({1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160});
+        EXPECT(not desc.empty());
+        EXPECT(get_final_lens(desc) == final_lens{1, 1, 1, 1, 1, 1, 32, 10, 16, 1, 90, 160});
+        EXPECT(get_all_lens(desc) ==
+               all_lens{{1}, {1}, {1}, {1}, {1}, {1}, {32}, {10}, {16}, {1}, {90}, {160}});
+        EXPECT(desc.generate() == ops{});
+    }
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }