diff --git a/CHANGELOG.md b/CHANGELOG.md index dead3a4dbbf..a5e25291e16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,24 @@ Full documentation for MIGraphX is available at [https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/](https://rocmdocs.amd.com/projects/AMDMIGraphX/en/latest/). + +## Develop + +### Added + +### Changed + +### Resolved issues +* Fixed BF16/FP16 Quantization failure via zero element concats (#4512). + +### Optimized + +### Removed + + + + + ## MIGraphX 2.15 for ROCm 7.2.0 ### Added diff --git a/src/include/migraphx/op/concat.hpp b/src/include/migraphx/op/concat.hpp index 527ef55bc1c..8b97f669622 100644 --- a/src/include/migraphx/op/concat.hpp +++ b/src/include/migraphx/op/concat.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -90,6 +90,13 @@ struct concat { if(ll != axis) { + // Skip if any input dimension is 0 since we'll optimize this out later + if(std::any_of( + inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[ll] == 0; })) + { + continue; + } + if(not std::all_of(inputs.begin(), inputs.end(), [&](auto s) { return s.lens()[ll] == first_shape_lens[ll]; })) diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 51b5e199db5..390c0e1805f 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -574,6 +574,50 @@ struct find_concat_multibroadcasts } }; +struct find_concat_zero_element_inputs +{ + auto matcher() const + { + return match::name("concat")(match::any_of[match::inputs()](match::make_basic_pred_matcher( + [](const auto& ins) { return ins->get_shape().elements() != 0; }))); + } + + void apply(module& m, const match::matcher_result& mr) const + { + auto ins = mr.result; + auto inputs = ins->inputs(); + + std::vector new_inputs; + + // Filter inputs that don't have zero element shapes + migraphx::transform_if( + inputs.begin(), + inputs.end(), + std::back_inserter(new_inputs), + [&](const auto& in) { return in->get_shape().elements() != 0; }, + [&](const auto& in) { return in; }); + + // Replace old concat with updated concat with updated inputs + if(new_inputs.empty()) + { + m.remove_instruction(ins); + } + else if(new_inputs.size() < inputs.size()) + { + if(new_inputs.size() == 1) + { + m.replace_instruction(ins, new_inputs.front()); + } + else + { + auto op = any_cast(ins->get_operator()); + auto concat = m.insert_instruction(ins, op, new_inputs); + m.replace_instruction(ins, concat); + } + } + } +}; + struct find_concat_slice { auto matcher() const @@ -1434,6 +1478,7 @@ void simplify_reshapes::apply(module& m) const find_concat_multibroadcasts{}, find_nested_slice{}, find_nested_concat{}, + find_concat_zero_element_inputs{}, find_transpose_slice{}, find_slice_transpose{}, find_unary_shape_transforms{}, diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index ce8f1f50eec..4c4723e0d84 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -3661,4 +3661,236 @@ TEST_CASE(conv_add_layernorm_conv) EXPECT(m1.sort() == m2.sort()); } +// TODO: Add test case for all zero inputs and determine how we want to hanlde this +/*TEST_CASE(concat_zero_element_inputs_all_zero) +{ + // Test case where all inputs have zero elements - concat should be removed + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 3}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + m1.add_return({concat}); + } + run_pass(m1); + + // After optimization, concat should be removed + EXPECT(std::count_if(m1.begin(), m1.end(), [](auto ins) { return ins.name() == "concat"; }) == + 0); +} */ + +TEST_CASE(concat_zero_element_inputs_some_zero) +{ + // Test case where some inputs have zero elements - they should be filtered out + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m2.add_parameter("x", s1); + m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s3); + // Only x and z should remain in the concat (y is filtered out) + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_first_zero) +{ + // Test case where the first input has zero elements + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s3); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), y, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_last_zero) +{ + // Test case where the last input has zero elements + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto x = m2.add_parameter("x", s1); + auto y = m2.add_parameter("y", s2); + m2.add_parameter("z", s3); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_multiple_zero) +{ + // Test case where multiple inputs have zero elements + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto w = m1.add_parameter("w", s1); + auto x = m1.add_parameter("x", s2); + auto y = m1.add_parameter("y", s3); + auto z = m1.add_parameter("z", s4); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), w, x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s4 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + m2.add_parameter("w", s1); + auto x = m2.add_parameter("x", s2); + m2.add_parameter("y", s3); + auto z = m2.add_parameter("z", s4); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_inputs_no_zero) +{ + // Test case where no inputs have zero elements - concat should remain unchanged + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 5, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 7, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y, z); + m1.add_return({concat}); + } + auto original_m1 = m1; + run_pass(m1); + + // Module should remain unchanged + EXPECT(m1.sort() == original_m1.sort()); +} + +TEST_CASE(concat_zero_element_inputs_different_axis) +{ + // Test case with zero elements on a different axis + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 0, 4}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}; + auto x = m2.add_parameter("x", s1); + m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s3); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, z); + m2.add_return({concat}); + } + + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(concat_zero_element_one_input) +{ + // Test case with zero elements on a different axis + migraphx::module m1; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + auto x = m1.add_parameter("x", s1); + auto y = m1.add_parameter("y", s2); + auto z = m1.add_parameter("z", s3); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), x, y, z); + m1.add_return({concat}); + } + run_pass(m1); + + migraphx::module m2; + { + auto s1 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s2 = migraphx::shape{migraphx::shape::float_type, {2, 3, 0}}; + auto s3 = migraphx::shape{migraphx::shape::float_type, {2, 3, 4}}; + m2.add_parameter("x", s1); + m2.add_parameter("y", s2); + auto z = m2.add_parameter("z", s3); + + m2.add_return({z}); + } + + EXPECT(m1.sort() == m2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }