Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,24 @@
Full documentation for MIGraphX is available at
[https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/).


## Develop

### Added

### Changed

### Resolved issues
* Fixed BF16/FP16 Quantization failure via zero element concats (#4512).

### Optimized

### Removed





## MIGraphX 2.15 for ROCm 7.2.0

### Added
Expand Down
9 changes: 8 additions & 1 deletion src/include/migraphx/op/concat.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -90,6 +90,13 @@ struct concat
{
if(ll != axis)
{
// Skip if any input dimension is 0 since we'll optimize this out later
if(std::any_of(
inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[ll] == 0; }))
{
continue;
}

if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) {
return s.lens()[ll] == first_shape_lens[ll];
}))
Expand Down
45 changes: 45 additions & 0 deletions src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,50 @@ struct find_concat_multibroadcasts
}
};

struct find_concat_zero_element_inputs
{
auto matcher() const
{
return match::name("concat")(match::any_of[match::inputs()](match::make_basic_pred_matcher(
[](const auto& ins) { return ins->get_shape().elements() != 0; })));
}

void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto inputs = ins->inputs();

std::vector<instruction_ref> new_inputs;

// Filter inputs that don't have zero element shapes
migraphx::transform_if(
inputs.begin(),
inputs.end(),
std::back_inserter(new_inputs),
[&](const auto& in) { return in->get_shape().elements() != 0; },
[&](const auto& in) { return in; });

// Replace old concat with updated concat with updated inputs
if(new_inputs.empty())
{
m.remove_instruction(ins);
}
else if(new_inputs.size() < inputs.size())
{
if(new_inputs.size() == 1)
{
m.replace_instruction(ins, new_inputs.front());
}
else
{
auto op = any_cast<op::concat>(ins->get_operator());
auto concat = m.insert_instruction(ins, op, new_inputs);
m.replace_instruction(ins, concat);
}
}
}
};

struct find_concat_slice
{
auto matcher() const
Expand Down Expand Up @@ -1434,6 +1478,7 @@ void simplify_reshapes::apply(module& m) const
find_concat_multibroadcasts{},
find_nested_slice{},
find_nested_concat{},
find_concat_zero_element_inputs{},
find_transpose_slice{},
find_slice_transpose{},
find_unary_shape_transforms{},
Expand Down
232 changes: 232 additions & 0 deletions test/simplify_reshapes_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3661,4 +3661,236 @@ TEST_CASE(conv_add_layernorm_conv)
EXPECT(m1.sort() == m2.sort());
}

// TODO: Add test case for all zero inputs and determine how we want to hanlde this
/*TEST_CASE(concat_zero_element_inputs_all_zero)
{
// Test case where all inputs have zero elements - concat should be removed
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}};
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
m1.add_return({concat});
}
run_pass(m1);

// After optimization, concat should be removed
EXPECT(std::count_if(m1.begin(), m1.end(), [](auto ins) { return ins.name() == "concat"; }) ==
0);
} */

TEST_CASE(concat_zero_element_inputs_some_zero)
{
// Test case where some inputs have zero elements - they should be filtered out
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not be creating shapes with zero elements. This is a problem in the onnx parser. It should not be fixed in the optimization passes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I thought at first but this is actually valid onnx and infact 0 is a valid index /input to the ConstantOfShape call which is what's used here in conjunction with concat.

Customer models we've been given are using zero input shapes to stub out inputs/parameters though and this pattern of concats is used to aggregate the inputs to make the final input.

We can't ask them to change the model as this is how they're toggling/flipping between parameters being removed from the model.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should still be fixed in the onnx parser since our operators are not written to handle empty elements. We can extend parse_generic to skip zero element tensors and then if the arguments are empty return an undefined instruction instead.

auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto z = m1.add_parameter("z", s3);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z);
m1.add_return({concat});
}
run_pass(m1);

migraphx::module m2;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
auto x = m2.add_parameter("x", s1);
m2.add_parameter("y", s2);
auto z = m2.add_parameter("z", s3);
// Only x and z should remain in the concat (y is filtered out)
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z);
m2.add_return({concat});
}

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(concat_zero_element_inputs_first_zero)
{
// Test case where the first input has zero elements
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto z = m1.add_parameter("z", s3);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z);
m1.add_return({concat});
}
run_pass(m1);

migraphx::module m2;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
m2.add_parameter("x", s1);
auto y = m2.add_parameter("y", s2);
auto z = m2.add_parameter("z", s3);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, z);
m2.add_return({concat});
}

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(concat_zero_element_inputs_last_zero)
{
// Test case where the last input has zero elements
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto z = m1.add_parameter("z", s3);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z);
m1.add_return({concat});
}
run_pass(m1);

migraphx::module m2;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto x = m2.add_parameter("x", s1);
auto y = m2.add_parameter("y", s2);
m2.add_parameter("z", s3);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y);
m2.add_return({concat});
}

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(concat_zero_element_inputs_multiple_zero)
{
// Test case where multiple inputs have zero elements
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
auto w = m1.add_parameter("w", s1);
auto x = m1.add_parameter("x", s2);
auto y = m1.add_parameter("y", s3);
auto z = m1.add_parameter("z", s4);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), w, x, y, z);
m1.add_return({concat});
}
run_pass(m1);

migraphx::module m2;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
m2.add_parameter("w", s1);
auto x = m2.add_parameter("x", s2);
m2.add_parameter("y", s3);
auto z = m2.add_parameter("z", s4);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z);
m2.add_return({concat});
}

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(concat_zero_element_inputs_no_zero)
{
// Test case where no inputs have zero elements - concat should remain unchanged
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 7, 4}};
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto z = m1.add_parameter("z", s3);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z);
m1.add_return({concat});
}
auto original_m1 = m1;
run_pass(m1);

// Module should remain unchanged
EXPECT(m1.sort() == original_m1.sort());
}

TEST_CASE(concat_zero_element_inputs_different_axis)
{
// Test case with zero elements on a different axis
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}};
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto z = m1.add_parameter("z", s3);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z);
m1.add_return({concat});
}
run_pass(m1);

migraphx::module m2;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}};
auto x = m2.add_parameter("x", s1);
m2.add_parameter("y", s2);
auto z = m2.add_parameter("z", s3);
auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, z);
m2.add_return({concat});
}

EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(concat_zero_element_one_input)
{
// Test case with zero elements on a different axis
migraphx::module m1;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
auto x = m1.add_parameter("x", s1);
auto y = m1.add_parameter("y", s2);
auto z = m1.add_parameter("z", s3);
auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z);
m1.add_return({concat});
}
run_pass(m1);

migraphx::module m2;
{
auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}};
auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}};
auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}};
m2.add_parameter("x", s1);
m2.add_parameter("y", s2);
auto z = m2.add_parameter("z", s3);

m2.add_return({z});
}

EXPECT(m1.sort() == m2.sort());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }
Loading