From 6eba8ef5c791c86f08a6b0c569be2dc6ddaf457a Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Wed, 17 Dec 2025 04:28:34 +0000 Subject: [PATCH 01/17] Add matcher for empty literal and concats remove empty element shaped literals remove inputs from concat that contain inputs with zero element size as the concat is irrelevant --- src/simplify_reshapes.cpp | 66 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 51b5e199db5..02c7adbd337 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -574,6 +574,70 @@ struct find_concat_multibroadcasts } }; +struct find_zero_element_literal +{ + auto matcher() const + { + return match::name("@literal"); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto lit = mr.result; + auto s = lit->get_shape(); + + if(s.elements() == 0) + { + std::cout << "Found zero element literal" << std::endl; + lit->debug_print(); + m.remove_instruction(lit); + } + } +}; + +struct find_concat_zero_element_inputs +{ + auto matcher() const + { + return match::name("concat"); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto inputs = ins->inputs(); + auto outs = ins->outputs(); + auto op = any_cast(ins->get_operator()); + + std::vector new_inputs; + + // Filter inputs that don't have zero element shapes + migraphx::transform_if( + inputs.begin(), + inputs.end(), + std::back_inserter(new_inputs), + [&](const auto& in) { return in->get_shape().elements() != 0; }, + [&](const auto& in) { return in; }); + + + // Replace old concat with updated concat with updated inputs + if(new_inputs.size() == 0) + { + std::cout << "found instruction to remove" << std::endl; + ins->debug_print(); + m.remove_instruction(ins); + } + else if (new_inputs.size() < inputs.size()) + { + std::cout << "found instruction to replace" << std::endl; + ins->debug_print(); + auto concat = m.insert_instruction(ins, op, new_inputs); + concat->debug_print(); + m.replace_instruction(ins, concat); + } + } +}; + struct find_concat_slice { auto matcher() const @@ -1422,6 +1486,7 @@ void simplify_reshapes::apply(module& m) const { m.repeat_while_changes(depth, [&] { match::find_matches(m, + find_zero_element_literal{}, find_where_op{}, find_resize{}, find_nop_reshapes{}, @@ -1432,6 +1497,7 @@ void simplify_reshapes::apply(module& m) const find_concat_transpose{}, find_concat_reshape{}, find_concat_multibroadcasts{}, + find_concat_zero_element_inputs{}, find_nested_slice{}, find_nested_concat{}, find_transpose_slice{}, From d107dc0b1e984707110e094ad40fdf5d8deb7d19 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 18 Dec 2025 18:51:59 +0000 Subject: [PATCH 02/17] Comment out literal find matcher --- src/simplify_reshapes.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 02c7adbd337..30c1dc212cc 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -574,7 +574,7 @@ struct find_concat_multibroadcasts } }; -struct find_zero_element_literal +/*struct find_zero_element_literal { auto matcher() const { @@ -593,7 +593,7 @@ struct find_zero_element_literal m.remove_instruction(lit); } } -}; +}; */ struct find_concat_zero_element_inputs { @@ -1486,7 +1486,7 @@ void simplify_reshapes::apply(module& m) const { m.repeat_while_changes(depth, [&] { match::find_matches(m, - find_zero_element_literal{}, + //find_zero_element_literal{}, find_where_op{}, find_resize{}, find_nop_reshapes{}, @@ -1497,9 +1497,9 @@ void simplify_reshapes::apply(module& m) const find_concat_transpose{}, find_concat_reshape{}, find_concat_multibroadcasts{}, - find_concat_zero_element_inputs{}, find_nested_slice{}, find_nested_concat{}, + find_concat_zero_element_inputs{}, find_transpose_slice{}, find_slice_transpose{}, find_unary_shape_transforms{}, From dfff5a5efef97a96e1329ce786b11333573f3fee Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 18 Dec 2025 13:33:14 -0600 Subject: [PATCH 03/17] Cleanup changes and remove literal pass --- src/simplify_reshapes.cpp | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 30c1dc212cc..dafe7ed0553 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -574,27 +574,6 @@ struct find_concat_multibroadcasts } }; -/*struct find_zero_element_literal -{ - auto matcher() const - { - return match::name("@literal"); - } - - void apply(module& m, const match::matcher_result& mr) const - { - auto lit = mr.result; - auto s = lit->get_shape(); - - if(s.elements() == 0) - { - std::cout << "Found zero element literal" << std::endl; - lit->debug_print(); - m.remove_instruction(lit); - } - } -}; */ - struct find_concat_zero_element_inputs { auto matcher() const @@ -623,16 +602,11 @@ struct find_concat_zero_element_inputs // Replace old concat with updated concat with updated inputs if(new_inputs.size() == 0) { - std::cout << "found instruction to remove" << std::endl; - ins->debug_print(); m.remove_instruction(ins); } else if (new_inputs.size() < inputs.size()) { - std::cout << "found instruction to replace" << std::endl; - ins->debug_print(); auto concat = m.insert_instruction(ins, op, new_inputs); - concat->debug_print(); m.replace_instruction(ins, concat); } } @@ -1486,7 +1460,6 @@ void simplify_reshapes::apply(module& m) const { m.repeat_while_changes(depth, [&] { match::find_matches(m, - //find_zero_element_literal{}, find_where_op{}, find_resize{}, find_nop_reshapes{}, From ff8983e38c02b0a22e68b410bbac41f7af5a206b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 18 Dec 2025 14:14:31 -0600 Subject: [PATCH 04/17] [AI Generated] Add comprehensive tests for find_concat_zero_element_inputs optimization Added 7 test cases to verify the find_concat_zero_element_inputs optimization in simplify_reshapes pass: - concat_zero_element_inputs_all_zero: Tests removal when all inputs are zero - concat_zero_element_inputs_some_zero: Tests filtering middle zero inputs - concat_zero_element_inputs_first_zero: Tests filtering first zero input - concat_zero_element_inputs_last_zero: Tests filtering last zero input - concat_zero_element_inputs_multiple_zero: Tests filtering multiple zeros - concat_zero_element_inputs_no_zero: Baseline test with no zero inputs - concat_zero_element_inputs_different_axis: Tests on different concat axis Tests cover all code paths in the optimization and verify that concat operations correctly filter out inputs with zero elements (shape.elements() == 0). --- test/simplify_reshapes_test.cpp | 200 ++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index ce8f1f50eec..91f9e1a836c 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3661,4 +3661,204 @@ TEST_CASE(conv_add_layernorm_conv) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(concat_zero_element_inputs_all_zero) +{ + // Test case where all inputs have zero elements - concat should be removed + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + // concat should be removed, but we still need a return value + // The optimization removes the concat, leaving the graph incomplete + // This test verifies the concat is removed + } + + // After optimization, concat should be removed + EXPECT(std::count_if(m1.begin(), m1.end(), [](auto ins) { return ins.name() == "concat"; }) == + 0); +} + +TEST_CASE(concat_zero_element_inputs_some_zero) +{ + // Test case where some inputs have zero elements - they should be filtered out + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m2.add_parameter("x", s1); + auto z = m2.add_parameter("z", s3); + // Only x and z should remain in the concat (y is filtered out) + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_first_zero) +{ + // Test case where the first input has zero elements + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto y = m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s3); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_last_zero) +{ + // Test case where the last input has zero elements + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_multiple_zero) +{ + // Test case where multiple inputs have zero elements + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto w = m1.add_parameter("w", s1); + auto x = m1.add_parameter("x", s2); + auto y = m1.add_parameter("y", s3); + auto z = m1.add_parameter("z", s4); + auto concat = + m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), w, x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m2.add_parameter("x", s2); + auto z = m2.add_parameter("z", s4); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_no_zero) +{ + // Test case where no inputs have zero elements - concat should remain unchanged + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 7, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + auto original_m1 = m1; + run_pass(m1); + + // Module should remain unchanged + EXPECT(m1.sort() == original_m1.sort()); +} + +TEST_CASE(concat_zero_element_inputs_different_axis) +{ + // Test case with zero elements on a different axis + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; + auto x = m2.add_parameter("x", s1); + auto z = m2.add_parameter("z", s3); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 914330cfec7bd6db60e609be4826b3a53411b379 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Thu, 18 Dec 2025 14:29:03 -0600 Subject: [PATCH 05/17] Update format --- test/simplify_reshapes_test.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 91f9e1a836c..f5549fd6f4f 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3679,8 +3679,6 @@ TEST_CASE(concat_zero_element_inputs_all_zero) { auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; - auto x = m2.add_parameter("x", s1); - auto y = m2.add_parameter("y", s2); // concat should be removed, but we still need a return value // The optimization removes the concat, leaving the graph incomplete // This test verifies the concat is removed From 8326e913f7f1e74d5e1fa56e534cbe55507e0f38 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 10:33:57 -0600 Subject: [PATCH 06/17] Fix tidy warning with size used as empty --- src/simplify_reshapes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index dafe7ed0553..bbf6fd2ef67 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -600,7 +600,7 @@ struct find_concat_zero_element_inputs // Replace old concat with updated concat with updated inputs - if(new_inputs.size() == 0) + if(new_inputs.empty() == 0) { m.remove_instruction(ins); } From fe972dc3fa991fff4fbfc59b2b885582a756ad99 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 10:34:17 -0600 Subject: [PATCH 07/17] Remove unused module m2 in test --- test/simplify_reshapes_test.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index f5549fd6f4f..3e5f1b067ce 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3675,15 +3675,6 @@ TEST_CASE(concat_zero_element_inputs_all_zero) } run_pass(m1); - migraphx::module m2; - { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; - // concat should be removed, but we still need a return value - // The optimization removes the concat, leaving the graph incomplete - // This test verifies the concat is removed - } - // After optimization, concat should be removed EXPECT(std::count_if(m1.begin(), m1.end(), [](auto ins) { return ins.name() == "concat"; }) == 0); From 31c7b8658eebdc1caa40f74565a5bcb8982cccb0 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 11:08:29 -0600 Subject: [PATCH 08/17] Updated changelog --- CHANGELOG.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dead3a4dbbf..3b368b58ca3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,24 @@ Full documentation for MIGraphX is available at [https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/). + +## MIGraphX 2.16 for ROCm 7.2.1 + +### Added + +### Changed + +### Resolved Issues +* Fixed BF16/FP16 Quantization failure via zero element concats (#4512). + +### Optimized + +### Removed + + + + + ## MIGraphX 2.15 for ROCm 7.2.0 ### Added From dfe2b1e68905b5bc8845c08c99c280e1ff67af4b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 11:34:56 -0600 Subject: [PATCH 09/17] Cleanup matcher --- src/simplify_reshapes.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index bbf6fd2ef67..17d9924111b 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -585,8 +585,6 @@ struct find_concat_zero_element_inputs { auto ins = mr.result; auto inputs = ins->inputs(); - auto outs = ins->outputs(); - auto op = any_cast(ins->get_operator()); std::vector new_inputs; @@ -598,14 +596,14 @@ struct find_concat_zero_element_inputs [&](const auto& in) { return in->get_shape().elements() != 0; }, [&](const auto& in) { return in; }); - // Replace old concat with updated concat with updated inputs - if(new_inputs.empty() == 0) + if(new_inputs.empty()) { m.remove_instruction(ins); } else if (new_inputs.size() < inputs.size()) { + auto op = any_cast(ins->get_operator()); auto concat = m.insert_instruction(ins, op, new_inputs); m.replace_instruction(ins, concat); } From 5e84c77135cb8c66599a30a760b1add6fbc5c3ff Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Fri, 19 Dec 2025 13:10:19 -0500 Subject: [PATCH 10/17] Update test/simplify_reshapes_test.cpp Co-authored-by: Charlie Lin --- test/simplify_reshapes_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 3e5f1b067ce..89b5a678e32 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3827,7 +3827,7 @@ TEST_CASE(concat_zero_element_inputs_different_axis) migraphx::module m1; { auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {0, 2, 4}}; auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; auto x = m1.add_parameter("x", s1); auto y = m1.add_parameter("y", s2); From caf5e88531b4f641a1c1e84ed32a831b5626172a Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 15:25:20 -0600 Subject: [PATCH 11/17] Update matcher, tests and concat ot handle zero input values - For zero input element's concats we'll catch them with the matcher and remove those inputs - Ensure reshape tastes are passing - Make matcher more narrow to target only concats with zero input elements --- src/include/migraphx/op/concat.hpp | 6 ++++ src/simplify_reshapes.cpp | 4 +-- test/simplify_reshapes_test.cpp | 45 +++++++++++++++++++++++++++++- 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index 527ef55bc1c..a811096878e 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -90,6 +90,12 @@ struct concat { if(ll != axis) { + // Skip if any input dimension is 0 since we'll optimize this out later + if(std::any_of(inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[ll] == 0; })) + { + continue; + } + if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[ll] == first_shape_lens[ll]; })) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 17d9924111b..8ca6ad63b50 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -578,7 +578,7 @@ struct find_concat_zero_element_inputs { auto matcher() const { - return match::name("concat"); + return match::name("concat")(match::any_of[match::inputs()](match::make_basic_pred_matcher([](const auto& ins) { return ins->get_shape().elements() != 0; }))); } void apply(module& m, const match::matcher_result& mr) const @@ -595,7 +595,7 @@ struct find_concat_zero_element_inputs std::back_inserter(new_inputs), [&](const auto& in) { return in->get_shape().elements() != 0; }, [&](const auto& in) { return in; }); - + // Replace old concat with updated concat with updated inputs if(new_inputs.empty()) { diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index 89b5a678e32..b2c17894e97 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3699,8 +3699,10 @@ TEST_CASE(concat_zero_element_inputs_some_zero) migraphx::module m2; { auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; auto x = m2.add_parameter("x", s1); + m2.add_parameter("y", s2); auto z = m2.add_parameter("z", s3); // Only x and z should remain in the concat (y is filtered out) auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); @@ -3728,8 +3730,10 @@ TEST_CASE(concat_zero_element_inputs_first_zero) migraphx::module m2; { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + m2.add_parameter("x", s1); auto y = m2.add_parameter("y", s2); auto z = m2.add_parameter("z", s3); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, z); @@ -3759,8 +3763,10 @@ TEST_CASE(concat_zero_element_inputs_last_zero) { auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto x = m2.add_parameter("x", s1); auto y = m2.add_parameter("y", s2); + m2.add_parameter("z", s3); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); m2.add_return({concat}); } @@ -3789,9 +3795,13 @@ TEST_CASE(concat_zero_element_inputs_multiple_zero) migraphx::module m2; { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + m2.add_parameter("w", s1); auto x = m2.add_parameter("x", s2); + m2.add_parameter("y", s3); auto z = m2.add_parameter("z", s4); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); m2.add_return({concat}); @@ -3827,7 +3837,7 @@ TEST_CASE(concat_zero_element_inputs_different_axis) migraphx::module m1; { auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {0, 2, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; auto x = m1.add_parameter("x", s1); auto y = m1.add_parameter("y", s2); @@ -3840,8 +3850,10 @@ TEST_CASE(concat_zero_element_inputs_different_axis) migraphx::module m2; { auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; auto x = m2.add_parameter("x", s1); + m2.add_parameter("y", s2); auto z = m2.add_parameter("z", s3); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, z); m2.add_return({concat}); @@ -3850,4 +3862,35 @@ TEST_CASE(concat_zero_element_inputs_different_axis) EXPECT(m1.sort() == m2.sort()); } +TEST_CASE(concat_zero_element_one_input) +{ + // Test case with zero elements on a different axis + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + m2.add_parameter("x", s1); + m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s3); + + m2.add_return({z}); + } + + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); } From 33a7a578e91dbb61ddcf0fee8ac21499bceb4cca Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 15:31:22 -0600 Subject: [PATCH 12/17] Stub out all zero case test - Needs further discussion --- test/simplify_reshapes_test.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index b2c17894e97..c8387157077 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3661,7 +3661,8 @@ TEST_CASE(conv_add_layernorm_conv) EXPECT(m1.sort() == m2.sort()); } -TEST_CASE(concat_zero_element_inputs_all_zero) +//TODO: Add test case for all zero inputs and determine how we want to hanlde this +/*TEST_CASE(concat_zero_element_inputs_all_zero) { // Test case where all inputs have zero elements - concat should be removed migraphx::module m1; @@ -3678,7 +3679,7 @@ TEST_CASE(concat_zero_element_inputs_all_zero) // After optimization, concat should be removed EXPECT(std::count_if(m1.begin(), m1.end(), [](auto ins) { return ins.name() == "concat"; }) == 0); -} +} */ TEST_CASE(concat_zero_element_inputs_some_zero) { From 242e0bd1560e9172ffea56806146c6556dad8f4b Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 15:33:07 -0600 Subject: [PATCH 13/17] Format --- src/simplify_reshapes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 8ca6ad63b50..fd9f3e1abb0 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -601,7 +601,7 @@ struct find_concat_zero_element_inputs { m.remove_instruction(ins); } - else if (new_inputs.size() < inputs.size()) + else if(new_inputs.size() < inputs.size()) { auto op = any_cast(ins->get_operator()); auto concat = m.insert_instruction(ins, op, new_inputs); From 6a1dfe621f506fe52a161d7e5b492973da166112 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 15:35:55 -0600 Subject: [PATCH 14/17] Handle new_inputs == 1 case to make this a simple noop for the concat --- src/simplify_reshapes.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index fd9f3e1abb0..5e6795b9bc7 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -603,9 +603,16 @@ struct find_concat_zero_element_inputs } else if(new_inputs.size() < inputs.size()) { - auto op = any_cast(ins->get_operator()); - auto concat = m.insert_instruction(ins, op, new_inputs); - m.replace_instruction(ins, concat); + if(new_inputs.size() == 1) + { + m.replace_instruction(ins, new_inputs.front()); + } + else + { + auto op = any_cast(ins->get_operator()); + auto concat = m.insert_instruction(ins, op, new_inputs); + m.replace_instruction(ins, concat); + } } } }; From 0ac321fb057405143c3c7b96da775d88bfe73204 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 15:38:52 -0600 Subject: [PATCH 15/17] Update Changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b368b58ca3..a5e25291e16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,13 +4,13 @@ Full documentation for MIGraphX is available at [https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/). -## MIGraphX 2.16 for ROCm 7.2.1 +## Develop ### Added ### Changed -### Resolved Issues +### Resolved issues * Fixed BF16/FP16 Quantization failure via zero element concats (#4512). ### Optimized From efd04b5e9ffa1090753ad53f08295afb5cd336c4 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 16:22:21 -0600 Subject: [PATCH 16/17] Update license for concat op --- src/include/migraphx/op/concat.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index a811096878e..cce136185bb 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 1a97049c7687cd5fdd7bed6003f24c7292ec5df4 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 16:47:17 -0600 Subject: [PATCH 17/17] Format --- src/include/migraphx/op/concat.hpp | 3 +- src/simplify_reshapes.cpp | 13 ++-- test/simplify_reshapes_test.cpp | 107 ++++++++++++++--------------- 3 files changed, 62 insertions(+), 61 deletions(-) diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index cce136185bb..8b97f669622 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -91,7 +91,8 @@ struct concat if(ll != axis) { // Skip if any input dimension is 0 since we'll optimize this out later - if(std::any_of(inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[ll] == 0; })) + if(std::any_of( + inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[ll] == 0; })) { continue; } diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 5e6795b9bc7..390c0e1805f 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -578,9 +578,10 @@ struct find_concat_zero_element_inputs { auto matcher() const { - return match::name("concat")(match::any_of[match::inputs()](match::make_basic_pred_matcher([](const auto& ins) { return ins->get_shape().elements() != 0; }))); + return match::name("concat")(match::any_of[match::inputs()](match::make_basic_pred_matcher( + [](const auto& ins) { return ins->get_shape().elements() != 0; }))); } - + void apply(module& m, const match::matcher_result& mr) const { auto ins = mr.result; @@ -595,7 +596,7 @@ struct find_concat_zero_element_inputs std::back_inserter(new_inputs), [&](const auto& in) { return in->get_shape().elements() != 0; }, [&](const auto& in) { return in; }); - + // Replace old concat with updated concat with updated inputs if(new_inputs.empty()) { @@ -609,9 +610,9 @@ struct find_concat_zero_element_inputs } else { - auto op = any_cast(ins->get_operator()); - auto concat = m.insert_instruction(ins, op, new_inputs); - m.replace_instruction(ins, concat); + auto op = any_cast(ins->get_operator()); + auto concat = m.insert_instruction(ins, op, new_inputs); + m.replace_instruction(ins, concat); } } } diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index c8387157077..4c4723e0d84 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3661,7 +3661,7 @@ TEST_CASE(conv_add_layernorm_conv) EXPECT(m1.sort() == m2.sort()); } -//TODO: Add test case for all zero inputs and determine how we want to hanlde this +// TODO: Add test case for all zero inputs and determine how we want to hanlde this /*TEST_CASE(concat_zero_element_inputs_all_zero) { // Test case where all inputs have zero elements - concat should be removed @@ -3686,12 +3686,12 @@ TEST_CASE(concat_zero_element_inputs_some_zero) // Test case where some inputs have zero elements - they should be filtered out migraphx::module m1; { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; - auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; - auto x = m1.add_parameter("x", s1); - auto y = m1.add_parameter("y", s2); - auto z = m1.add_parameter("z", s3); + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); m1.add_return({concat}); } @@ -3704,7 +3704,7 @@ TEST_CASE(concat_zero_element_inputs_some_zero) auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; auto x = m2.add_parameter("x", s1); m2.add_parameter("y", s2); - auto z = m2.add_parameter("z", s3); + auto z = m2.add_parameter("z", s3); // Only x and z should remain in the concat (y is filtered out) auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); m2.add_return({concat}); @@ -3718,12 +3718,12 @@ TEST_CASE(concat_zero_element_inputs_first_zero) // Test case where the first input has zero elements migraphx::module m1; { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; - auto x = m1.add_parameter("x", s1); - auto y = m1.add_parameter("y", s2); - auto z = m1.add_parameter("z", s3); + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); m1.add_return({concat}); } @@ -3735,8 +3735,8 @@ TEST_CASE(concat_zero_element_inputs_first_zero) auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; m2.add_parameter("x", s1); - auto y = m2.add_parameter("y", s2); - auto z = m2.add_parameter("z", s3); + auto y = m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s3); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, z); m2.add_return({concat}); } @@ -3749,12 +3749,12 @@ TEST_CASE(concat_zero_element_inputs_last_zero) // Test case where the last input has zero elements migraphx::module m1; { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; - auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; - auto x = m1.add_parameter("x", s1); - auto y = m1.add_parameter("y", s2); - auto z = m1.add_parameter("z", s3); + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); m1.add_return({concat}); } @@ -3780,16 +3780,15 @@ TEST_CASE(concat_zero_element_inputs_multiple_zero) // Test case where multiple inputs have zero elements migraphx::module m1; { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; - auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; - auto w = m1.add_parameter("w", s1); - auto x = m1.add_parameter("x", s2); - auto y = m1.add_parameter("y", s3); - auto z = m1.add_parameter("z", s4); - auto concat = - m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), w, x, y, z); + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto w = m1.add_parameter("w", s1); + auto x = m1.add_parameter("x", s2); + auto y = m1.add_parameter("y", s3); + auto z = m1.add_parameter("z", s4); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), w, x, y, z); m1.add_return({concat}); } run_pass(m1); @@ -3801,9 +3800,9 @@ TEST_CASE(concat_zero_element_inputs_multiple_zero) auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; m2.add_parameter("w", s1); - auto x = m2.add_parameter("x", s2); + auto x = m2.add_parameter("x", s2); m2.add_parameter("y", s3); - auto z = m2.add_parameter("z", s4); + auto z = m2.add_parameter("z", s4); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); m2.add_return({concat}); } @@ -3816,12 +3815,12 @@ TEST_CASE(concat_zero_element_inputs_no_zero) // Test case where no inputs have zero elements - concat should remain unchanged migraphx::module m1; { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; - auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 7, 4}}; - auto x = m1.add_parameter("x", s1); - auto y = m1.add_parameter("y", s2); - auto z = m1.add_parameter("z", s3); + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 7, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); m1.add_return({concat}); } @@ -3837,12 +3836,12 @@ TEST_CASE(concat_zero_element_inputs_different_axis) // Test case with zero elements on a different axis migraphx::module m1; { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; - auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; - auto x = m1.add_parameter("x", s1); - auto y = m1.add_parameter("y", s2); - auto z = m1.add_parameter("z", s3); + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z); m1.add_return({concat}); } @@ -3855,7 +3854,7 @@ TEST_CASE(concat_zero_element_inputs_different_axis) auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; auto x = m2.add_parameter("x", s1); m2.add_parameter("y", s2); - auto z = m2.add_parameter("z", s3); + auto z = m2.add_parameter("z", s3); auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, z); m2.add_return({concat}); } @@ -3868,12 +3867,12 @@ TEST_CASE(concat_zero_element_one_input) // Test case with zero elements on a different axis migraphx::module m1; { - auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; - auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; - auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; - auto x = m1.add_parameter("x", s1); - auto y = m1.add_parameter("y", s2); - auto z = m1.add_parameter("z", s3); + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z); m1.add_return({concat}); } @@ -3887,7 +3886,7 @@ TEST_CASE(concat_zero_element_one_input) m2.add_parameter("x", s1); m2.add_parameter("y", s2); auto z = m2.add_parameter("z", s3); - + m2.add_return({z}); }