From 3a9cff679e99dfb44a6c9ef0a0d5d0e5eb3ef78d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 18 Nov 2025 23:16:37 -0600 Subject: [PATCH 01/30] initial --- src/fuse_attention.cpp | 53 ++++++++++-- src/include/migraphx/split_factor.hpp | 116 ++++++++++++++++++++++++++ src/rewrite_topk.cpp | 19 +---- src/targets/gpu/jit/reduce.cpp | 17 +--- 4 files changed, 168 insertions(+), 37 deletions(-) create mode 100644 src/include/migraphx/split_factor.hpp diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 7fc29634c4f..0d3004afd98 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -37,7 +38,12 @@ inline namespace MIGRAPHX_INLINE_NS { namespace { +// Environment variables for flash decoding configuration MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_AUTO_SPLIT); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); std::size_t get_num_splits() { return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); } @@ -216,8 +222,9 @@ struct find_attention struct find_flash_decoding { - // number of groups. User-provided for now + // number of groups (0 means auto-calculate) std::size_t groups; + bool use_auto_split = false; auto matcher() const { @@ -449,8 +456,30 @@ struct find_flash_decoding assert(k_param->name() == "@param" and "K should be a parameter"); assert(v_param->name() == "@param" and "V should be a parameter"); + // Get sequence length from K shape + auto k_shape = k_param->get_shape(); + std::size_t sequence_length = k_shape.lens().back(); + + // Determine actual number of splits to use + std::size_t actual_groups = groups; + if(use_auto_split) + { + // Check if sequence length meets threshold for flash decoding + std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 1024); + if(sequence_length < threshold) + return; + + std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 256); + std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 32); + actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); + + // Skip if auto-calculation determines no splitting needed + if(actual_groups <= 1) + return; + } + // check if N dimension is evenly divisible by num_splits - if(k_param->get_shape().lens().back() % groups != 0) + if(sequence_length % actual_groups != 0) return; // get Q, V, K shapes from gemms @@ -462,6 +491,9 @@ struct find_flash_decoding // create mapping from submodule params to main module inputs auto group_inputs = attn_group_ins->inputs(); auto map_param_to_main = map_submod_params_to_inputs(submod, group_inputs); + + // Use actual_groups for all subsequent operations + groups = actual_groups; // get actual Q, K, V instructions from main module auto q = map_param_to_main.at(q_param); @@ -797,20 +829,29 @@ void fuse_attention::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); } + // Determine split strategy + bool use_auto_split = enabled(MIGRAPHX_FLASH_DECODING_AUTO_SPLIT{}); std::size_t num_splits = 0; + if(flash_decoding_num_splits.has_value()) { // Use the value from the constructor (for testing) num_splits = *flash_decoding_num_splits; + use_auto_split = false; // Disable auto for testing } - else + else if(!use_auto_split) { - // Default behavior: read from the env var (for non-test use) + // Legacy behavior: read from the env var (for non-test use) num_splits = get_num_splits(); } - if(num_splits > 1) + + // Apply flash decoding with either manual or automatic splitting + if(use_auto_split || num_splits > 1) { - match::find_matches(mpm, find_flash_decoding{.groups = num_splits}); + match::find_matches(mpm, find_flash_decoding{ + .groups = num_splits, + .use_auto_split = use_auto_split + }); mpm.run_pass(dead_code_elimination{}); } } diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp new file mode 100644 index 00000000000..188d92b40ea --- /dev/null +++ b/src/include/migraphx/split_factor.hpp @@ -0,0 +1,116 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_SPLIT_FACTOR_HPP +#define MIGRAPHX_GUARD_SPLIT_FACTOR_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +/** + * Calculate split factor for a dimension to make it less than min_size. + * + * This function finds the largest divisor that can divide the dimension + * to make it less than min_size. It uses prime factors [2, 3, 5, 7, 11] + * to find good divisors that work well for parallel execution. + * + * Used by: + * - rewrite_topk: Splits large topk operations for better performance + * - flash_decoding: Splits attention sequence dimension for parallelization + * - split_reduce (GPU JIT): Splits reduction operations + * + * @param r The dimension size to split (will be modified to remaining size) + * @param min_size The minimum size threshold + * @return The split factor (number of groups) + */ +inline std::size_t split_dim(std::size_t& r, std::size_t min_size) +{ + std::size_t n = 1; + auto factors = make_array(2, 3, 5, 7, 11); + while(r > min_size) + { + // NOLINTNEXTLINE(readability-qualified-auto) + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { + return r % d == 0; + }); + if(it == factors.end()) + break; + r /= *it; + n *= *it; + } + return n; +} + +/** + * Calculate split factor with maximum splits constraint. + * + * Similar to split_dim but also respects a maximum number of splits. + * Useful when there's a limit on parallelization due to hardware constraints. + * + * @param dimension The dimension size to split + * @param min_size Minimum size per chunk after splitting + * @param max_splits Maximum number of splits allowed (0 = no limit) + * @return The split factor that respects both constraints + */ +inline std::size_t split_dim_with_max(std::size_t dimension, + std::size_t min_size, + std::size_t max_splits = 0) +{ + // Make a copy since split_dim modifies the value + std::size_t remaining = dimension; + std::size_t num_splits = split_dim(remaining, min_size); + + // If no max constraint or already within limit, return as is + if(max_splits == 0 || num_splits <= max_splits) + return num_splits; + + // Reduce splits to respect max_splits constraint + auto factors = make_array(2, 3, 5, 7, 11); + while(num_splits > max_splits) + { + // Remove the smallest prime factor to reduce splits + for(auto factor : factors) + { + if(num_splits % factor == 0) + { + num_splits /= factor; + remaining *= factor; + break; + } + } + // Safety check to avoid infinite loop + if(num_splits > max_splits && num_splits < 2) + break; + } + + return num_splits; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif // MIGRAPHX_GUARD_SPLIT_FACTOR_HPP diff --git a/src/rewrite_topk.cpp b/src/rewrite_topk.cpp index 55fd0e2bb79..e94f37afdc7 100644 --- a/src/rewrite_topk.cpp +++ b/src/rewrite_topk.cpp @@ -26,7 +26,7 @@ #include #include #include -#include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -38,23 +38,6 @@ struct find_large_topk std::size_t n_threshold = 0; auto matcher() const { return match::name("topk"); } - static std::size_t split_dim(std::size_t& r, std::size_t min_size) - { - std::size_t n = 1; - auto factors = make_array(2, 3, 5, 7, 11); - while(r > min_size) - { - // NOLINTNEXTLINE(readability-qualified-auto) - auto it = - std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); - if(it == factors.end()) - break; - r /= *it; - n *= *it; - } - return n; - } - void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index c8fcf89b04c..4da8c1ad49c 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include namespace migraphx { @@ -175,18 +175,9 @@ static std::vector split_reduce(const std::vector& inputs, assert(faxis < reduce_shape.lens().size()); - std::size_t n = 1; - auto r = input_shape.lens()[faxis]; - auto factors = make_array(2, 3, 5, 7, 11); - while(r > min_size and n < max_splits) - { - // NOLINTNEXTLINE(readability-qualified-auto) - auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); - if(it == factors.end()) - break; - r /= *it; - n *= *it; - } + auto r = input_shape.lens()[faxis]; + // Use common split_dim_with_max function from split_factor.hpp + std::size_t n = split_dim_with_max(r, min_size, max_splits); assert(n != 1); std::transform( inputs.begin(), inputs.end(), std::back_inserter(result), [&](const shape& s) -> shape { From 8259f9a521851df2bfed5aa53b1e749a5caaed8c Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 24 Nov 2025 00:41:57 -0600 Subject: [PATCH 02/30] auto split --- docs/reference/MIGraphX-dev-env-vars.rst | 44 ++++++++- requirements.txt | 2 +- src/fuse_attention.cpp | 119 ++++++++++++++++++----- src/include/migraphx/fuse_attention.hpp | 2 +- src/include/migraphx/split_factor.hpp | 1 + test/fuse_attention.cpp | 6 +- 6 files changed, 138 insertions(+), 36 deletions(-) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index ac86474ddcf..1c0bc531c5e 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -160,13 +160,49 @@ Model performance tunable variables change the compilation behavior of a model. | Default: Split-k performance configurations are turned off. + * - | ``MIGRAPHX_FLASH_DECODING_ENABLED`` + - | When set, flash decoding optimization for attention fusion is enabled, which splits the key-value sequence dimension for improved performance on long sequences. + - | ``1``: Enables flash decoding optimization. + | ``0``: Disables flash decoding optimization. + + | Default: ``0`` (disabled). + * - | ``MIGRAPHX_FLASH_DECODING_NUM_SPLITS`` - | Turns on flash decoding for attention fusion and sets the number of splits along the key-value sequence dimension. + | Sets the number of splits along the key-value sequence dimension when flash decoding is enabled. + + - | ``0`` or negative: Automatically calculates optimal splits based on sequence length. + | ``N`` (where N > 0): Uses exactly N splits along the key-value sequence dimension. + + | Default: ``0`` (automatic calculation). + + | Note: This variable is only used when ``MIGRAPHX_FLASH_DECODING_ENABLED=1``. + + * - | ``MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE`` + | Sets the minimum chunk size for automatic split calculation in flash decoding. - - | ``0``: Flash decoding is turned off (i.e., number of splits is 0). - | ``N`` (where N > 1): Enables flash decoding with N splits along the key-value sequence dimension. For example, ``2`` enables flash decoding with 2 splits, ``4`` with 4 splits, etc. + - | Takes a positive integer representing the minimum size of each chunk after splitting. + + | Default: ``256``. + + | Note: Only used when automatic split calculation is enabled. - | Default: flash decoding is turned off. + * - | ``MIGRAPHX_FLASH_DECODING_MAX_SPLITS`` + | Sets the maximum number of splits allowed during automatic split calculation. + + - | Takes a positive integer representing the maximum number of splits. + + | Default: ``32``. + + | Note: Only used when automatic split calculation is enabled. + + * - | ``MIGRAPHX_FLASH_DECODING_THRESHOLD`` + | Sets the minimum sequence length threshold for flash decoding to be applied. + + - | Takes a positive integer. Flash decoding is only applied when the sequence length is greater than or equal to this threshold. + + | Default: ``1024``. + + | Note: Only used when automatic split calculation is enabled. * - | ``MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT`` | When set, FP16 is not converted to FP32 in the ``InstanceNormalization`` ONNX operator. diff --git a/requirements.txt b/requirements.txt index b2cdb44557c..87071e0139b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@3a034fd4fa6c25ada90f1786700748b2f58aaf85 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off \ No newline at end of file +ROCm/rocMLIR@23aa5615d5a942e15f734a4bc5144b8e3b07ad16 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 0d3004afd98..aff5567ec02 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -29,7 +29,7 @@ #include #include #include -#include +#include #include #include @@ -38,15 +38,38 @@ inline namespace MIGRAPHX_INLINE_NS { namespace { -// Environment variables for flash decoding configuration +// env vars for flash decoding configuration +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_ENABLED); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS); -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_AUTO_SPLIT); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); +// Check for MLIR attention ops usage +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); + +bool is_mlir_attention_enabled() +{ + auto ops = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, ""); + return ops.find("attention") != std::string::npos; +} + +bool is_flash_decoding_enabled() +{ + // flash decoding is enabled if explicitly enabled + return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); +} + std::size_t get_num_splits() { return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); } +// calculate optimal flash decoding splits +inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, + std::size_t min_chunk_size, + std::size_t max_splits) +{ + return split_dim_with_max(sequence_length, min_chunk_size, max_splits); +} + // TODO: Write this in matcher.hpp as a general matcher for iterating through inputs inline auto pointwise_inputs() { @@ -222,9 +245,9 @@ struct find_attention struct find_flash_decoding { - // number of groups (0 means auto-calculate) + // number of groups (0 or negative means auto-calculate) std::size_t groups; - bool use_auto_split = false; + bool auto_calculate = false; auto matcher() const { @@ -274,7 +297,8 @@ struct find_flash_decoding std::vector v_shape; // final V shape: [B, G, N/G, D] }; - transformed_shapes_result get_transformed_shapes(const std::vector& input_shapes) const + transformed_shapes_result get_transformed_shapes(const std::vector& input_shapes, + std::size_t num_groups) const { assert(input_shapes.size() == 3 and "Expected Q, K, V shapes"); @@ -286,7 +310,7 @@ struct find_flash_decoding // 4D: Q_lens = [B, H, M, k] size_t ndim = q_lens.size(); size_t n = k_lens[ndim - 1]; - size_t g = groups; + size_t g = num_groups; // TODO: handle uneven splits; this is caught in `apply` for now assert(n % g == 0 and @@ -434,6 +458,7 @@ struct find_flash_decoding void apply(module_pass_manager& mpm, const match::matcher_result& r) const { + std::cout << "apply flash decoding" << std::endl; auto& mm = mpm.get_module(); auto attn_group_ins = r.instructions["group"]; auto* submod = attn_group_ins->module_inputs().front(); @@ -462,15 +487,17 @@ struct find_flash_decoding // Determine actual number of splits to use std::size_t actual_groups = groups; - if(use_auto_split) + if(auto_calculate || groups <= 0) { - // Check if sequence length meets threshold for flash decoding + // Auto-calculate the optimal number of splits + std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 256); + std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 32); std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 1024); + + // Check if sequence length meets threshold for flash decoding if(sequence_length < threshold) return; - - std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 256); - std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 32); + actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); // Skip if auto-calculation determines no splitting needed @@ -478,6 +505,10 @@ struct find_flash_decoding return; } + // Skip if no actual splitting (num_splits must be > 1) + if(actual_groups <= 1) + return; + // check if N dimension is evenly divisible by num_splits if(sequence_length % actual_groups != 0) return; @@ -486,14 +517,11 @@ struct find_flash_decoding auto qkv_shapes = get_qkv_shapes(q_param, k_param, v_param); // check shapes are ok and get flash decoding transformed shapes (Q', V', K') - auto transform_info = get_transformed_shapes(qkv_shapes); + auto transform_info = get_transformed_shapes(qkv_shapes, actual_groups); // create mapping from submodule params to main module inputs auto group_inputs = attn_group_ins->inputs(); auto map_param_to_main = map_submod_params_to_inputs(submod, group_inputs); - - // Use actual_groups for all subsequent operations - groups = actual_groups; // get actual Q, K, V instructions from main module auto q = map_param_to_main.at(q_param); @@ -822,35 +850,72 @@ void fuse_attention::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); // Only fuse plain attention when requested - if(attn_enabled) + if(attn_enabled || is_mlir_attention_enabled()) { + std::cout << "fuse plain attention" << std::endl; + mpm.get_module().debug_print(); match::find_matches(mpm, find_attention{.counter = &counter}); mpm.get_module().sort(); mpm.run_pass(dead_code_elimination{}); + mpm.get_module().debug_print(); } - // Determine split strategy - bool use_auto_split = enabled(MIGRAPHX_FLASH_DECODING_AUTO_SPLIT{}); + // Check if flash decoding is enabled + bool flash_enabled = false; std::size_t num_splits = 0; + bool auto_calculate = false; + // for testing, check if the number of splits has been explicitly set if(flash_decoding_num_splits.has_value()) { - // Use the value from the constructor (for testing) - num_splits = *flash_decoding_num_splits; - use_auto_split = false; // Disable auto for testing + std::cout << "flash decoding is explicitly enabled via constructor" << std::endl; + std::cout << "flash_decoding_num_splits: " << *flash_decoding_num_splits << std::endl; + flash_enabled = true; + // constructor value provided (for testing) - consider it enabled if > 0 + if(*flash_decoding_num_splits > 0) + { + num_splits = *flash_decoding_num_splits; + auto_calculate = false; + } + else + { + // constructor provided 0 or negative - enable with auto-calculation + num_splits = 0; + auto_calculate = true; + } } - else if(!use_auto_split) + else if(is_flash_decoding_enabled()) { - // Legacy behavior: read from the env var (for non-test use) + std::cout << "flash decoding is explicitly enabled via environment variable" << std::endl; + // flash decoding is explicitly enabled via environment variable + flash_enabled = true; + + // check if user specified number of splits num_splits = get_num_splits(); + std::cout << "num_splits: " << num_splits << std::endl; + if(num_splits > 0) + { + // user specified a positive number of splits - use it + auto_calculate = false; + } + else + { + // user didn't specify or specified 0/negative - auto-calculate + num_splits = 0; + auto_calculate = true; + } } - // Apply flash decoding with either manual or automatic splitting - if(use_auto_split || num_splits > 1) + // Apply flash decoding if enabled + if(flash_enabled) { + std::cout << "flash_enabled: " << flash_enabled << std::endl; + std::cout << "num_splits: " << num_splits << std::endl; + std::cout << "auto_calculate: " << auto_calculate << std::endl; + mpm.get_module().debug_print(); match::find_matches(mpm, find_flash_decoding{ .groups = num_splits, - .use_auto_split = use_auto_split + .auto_calculate = auto_calculate }); mpm.run_pass(dead_code_elimination{}); } diff --git a/src/include/migraphx/fuse_attention.hpp b/src/include/migraphx/fuse_attention.hpp index fe4b05fb3c0..db07c9aac02 100644 --- a/src/include/migraphx/fuse_attention.hpp +++ b/src/include/migraphx/fuse_attention.hpp @@ -37,8 +37,8 @@ struct module_pass_manager; struct MIGRAPHX_EXPORT fuse_attention { - std::optional flash_decoding_num_splits = std::nullopt; bool attn_enabled = false; + std::optional flash_decoding_num_splits = std::nullopt; std::string name() const { return "fuse_attention"; } void apply(module_pass_manager& mpm) const; diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index 188d92b40ea..b57db2d5257 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -66,6 +66,7 @@ inline std::size_t split_dim(std::size_t& r, std::size_t min_size) } /** + * TODO: lessen code reuse, save factor array * Calculate split factor with maximum splits constraint. * * Similar to split_dim but also respects a maximum number of splits. diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 27c92015364..41b145fc267 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -445,7 +445,7 @@ TEST_CASE(gemm_softmax_gemm_flash_decoding) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.flash_decoding_num_splits = 2, .attn_enabled = true}); + run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -546,7 +546,7 @@ TEST_CASE(flash_decoding_3d) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.flash_decoding_num_splits = 2, .attn_enabled = true}); + run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { @@ -661,7 +661,7 @@ TEST_CASE(flash_decoding_disabled) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.flash_decoding_num_splits = 0, .attn_enabled = true}); + run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 0}); // Expected result: only attention fusion, no flash decoding migraphx::program p2; From 78951196946092317dd0f3872705cfde01dcc83c Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 24 Nov 2025 01:01:17 -0600 Subject: [PATCH 03/30] blah --- test/verify/test_attention_flash_decoding_4d.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/verify/test_attention_flash_decoding_4d.cpp b/test/verify/test_attention_flash_decoding_4d.cpp index ffbe725b6e2..4ab98eeab5b 100644 --- a/test/verify/test_attention_flash_decoding_4d.cpp +++ b/test/verify/test_attention_flash_decoding_4d.cpp @@ -62,7 +62,6 @@ struct test_attention_flash_decoding_4d : verify_program; -// template struct test_attention_flash_decoding_4d; -// template struct test_attention_flash_decoding_4d; +template struct test_attention_flash_decoding_4d; +template struct test_attention_flash_decoding_4d; +template struct test_attention_flash_decoding_4d; From 4a9e91816185fa06d8faf89cfd773eca3cedf001 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 4 Dec 2025 11:31:19 -0600 Subject: [PATCH 04/30] make 3d test more interesting --- .../test_attention_flash_decoding_3d.cpp | 42 +++++++++++-------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/test/verify/test_attention_flash_decoding_3d.cpp b/test/verify/test_attention_flash_decoding_3d.cpp index a0aa41619d5..5471940e1f8 100644 --- a/test/verify/test_attention_flash_decoding_3d.cpp +++ b/test/verify/test_attention_flash_decoding_3d.cpp @@ -32,27 +32,35 @@ struct test_attention_flash_decoding_3d : verify_programadd_parameter("q", s_3d); - auto b = mm->add_parameter("k", s_3d); - auto b1 = mm->add_parameter("v", s_3d); - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); - b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); - rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), + auto a = mm->add_parameter("q", q_shape); + auto b = mm->add_parameter("k", k_shape); + auto b1 = mm->add_parameter("v", v_shape); + + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {1, 64, 256} + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); // {1, 64, 1} + rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), rmax); - auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); - auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); - rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), - rsum); - auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); - auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // {1, 64, 256} + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // {1, 64, 256} + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); // {1, 64, 1} + rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), + rsum); // {1, 64, 256} + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // {1, 64, 256} + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // {1, 64, 32} mm->add_return({gemm2}); return p1; } From d237085726932faf206b95bd4283280d6ebba8c5 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 4 Dec 2025 14:56:34 -0600 Subject: [PATCH 05/30] remove rocmlir update since it is on mainline now --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 87071e0139b..b2cdb44557c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,4 +29,4 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5 sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@23aa5615d5a942e15f734a4bc5144b8e3b07ad16 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/rocMLIR@3a034fd4fa6c25ada90f1786700748b2f58aaf85 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off \ No newline at end of file From b7e9f22340b9639413cd1923c67b12812371ac48 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 00:45:46 -0600 Subject: [PATCH 06/30] remove comments and clean up --- src/fuse_attention.cpp | 32 +++----------------------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index aff5567ec02..f2557e68c41 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -45,18 +45,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); -// Check for MLIR attention ops usage -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_USE_SPECIFIC_OPS); - -bool is_mlir_attention_enabled() -{ - auto ops = string_value_of(MIGRAPHX_MLIR_USE_SPECIFIC_OPS{}, ""); - return ops.find("attention") != std::string::npos; -} - bool is_flash_decoding_enabled() -{ - // flash decoding is enabled if explicitly enabled +{ return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); } @@ -458,7 +448,6 @@ struct find_flash_decoding void apply(module_pass_manager& mpm, const match::matcher_result& r) const { - std::cout << "apply flash decoding" << std::endl; auto& mm = mpm.get_module(); auto attn_group_ins = r.instructions["group"]; auto* submod = attn_group_ins->module_inputs().front(); @@ -850,17 +839,14 @@ void fuse_attention::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); // Only fuse plain attention when requested - if(attn_enabled || is_mlir_attention_enabled()) + if(attn_enabled) { - std::cout << "fuse plain attention" << std::endl; - mpm.get_module().debug_print(); match::find_matches(mpm, find_attention{.counter = &counter}); mpm.get_module().sort(); mpm.run_pass(dead_code_elimination{}); - mpm.get_module().debug_print(); } - // Check if flash decoding is enabled + // Apply flash decoding if enabled bool flash_enabled = false; std::size_t num_splits = 0; bool auto_calculate = false; @@ -868,8 +854,6 @@ void fuse_attention::apply(module_pass_manager& mpm) const // for testing, check if the number of splits has been explicitly set if(flash_decoding_num_splits.has_value()) { - std::cout << "flash decoding is explicitly enabled via constructor" << std::endl; - std::cout << "flash_decoding_num_splits: " << *flash_decoding_num_splits << std::endl; flash_enabled = true; // constructor value provided (for testing) - consider it enabled if > 0 if(*flash_decoding_num_splits > 0) @@ -886,18 +870,13 @@ void fuse_attention::apply(module_pass_manager& mpm) const } else if(is_flash_decoding_enabled()) { - std::cout << "flash decoding is explicitly enabled via environment variable" << std::endl; // flash decoding is explicitly enabled via environment variable flash_enabled = true; // check if user specified number of splits num_splits = get_num_splits(); - std::cout << "num_splits: " << num_splits << std::endl; if(num_splits > 0) - { - // user specified a positive number of splits - use it auto_calculate = false; - } else { // user didn't specify or specified 0/negative - auto-calculate @@ -906,13 +885,8 @@ void fuse_attention::apply(module_pass_manager& mpm) const } } - // Apply flash decoding if enabled if(flash_enabled) { - std::cout << "flash_enabled: " << flash_enabled << std::endl; - std::cout << "num_splits: " << num_splits << std::endl; - std::cout << "auto_calculate: " << auto_calculate << std::endl; - mpm.get_module().debug_print(); match::find_matches(mpm, find_flash_decoding{ .groups = num_splits, .auto_calculate = auto_calculate From 2c20d78e6ca2eeb59b19cb19168f0934389c56c4 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 01:01:35 -0600 Subject: [PATCH 07/30] AIMIGRAPHX-289 AIMIGRAPHX-341 ; remove comment --- src/include/migraphx/split_factor.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index b57db2d5257..188d92b40ea 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -66,7 +66,6 @@ inline std::size_t split_dim(std::size_t& r, std::size_t min_size) } /** - * TODO: lessen code reuse, save factor array * Calculate split factor with maximum splits constraint. * * Similar to split_dim but also respects a maximum number of splits. From 6bda8f77d95c93da3a8cac31feb7ed4e8c31d048 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 01:09:20 -0600 Subject: [PATCH 08/30] format --- src/fuse_attention.cpp | 49 +++++++++---------- src/include/migraphx/split_factor.hpp | 33 ++++++------- .../test_attention_flash_decoding_3d.cpp | 19 ++++--- 3 files changed, 48 insertions(+), 53 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index f2557e68c41..54a2b15d2b5 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -45,17 +45,14 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); -bool is_flash_decoding_enabled() -{ - return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); -} +bool is_flash_decoding_enabled() { return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); } std::size_t get_num_splits() { return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); } // calculate optimal flash decoding splits -inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, - std::size_t min_chunk_size, - std::size_t max_splits) +inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, + std::size_t min_chunk_size, + std::size_t max_splits) { return split_dim_with_max(sequence_length, min_chunk_size, max_splits); } @@ -288,7 +285,7 @@ struct find_flash_decoding }; transformed_shapes_result get_transformed_shapes(const std::vector& input_shapes, - std::size_t num_groups) const + std::size_t num_groups) const { assert(input_shapes.size() == 3 and "Expected Q, K, V shapes"); @@ -471,33 +468,33 @@ struct find_flash_decoding assert(v_param->name() == "@param" and "V should be a parameter"); // Get sequence length from K shape - auto k_shape = k_param->get_shape(); + auto k_shape = k_param->get_shape(); std::size_t sequence_length = k_shape.lens().back(); - + // Determine actual number of splits to use std::size_t actual_groups = groups; if(auto_calculate || groups <= 0) { // Auto-calculate the optimal number of splits - std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 256); + std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 256); std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 32); - std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 1024); + std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 1024); // Check if sequence length meets threshold for flash decoding if(sequence_length < threshold) return; actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); - + // Skip if auto-calculation determines no splitting needed if(actual_groups <= 1) return; } - + // Skip if no actual splitting (num_splits must be > 1) if(actual_groups <= 1) return; - + // check if N dimension is evenly divisible by num_splits if(sequence_length % actual_groups != 0) return; @@ -847,10 +844,10 @@ void fuse_attention::apply(module_pass_manager& mpm) const } // Apply flash decoding if enabled - bool flash_enabled = false; + bool flash_enabled = false; std::size_t num_splits = 0; - bool auto_calculate = false; - + bool auto_calculate = false; + // for testing, check if the number of splits has been explicitly set if(flash_decoding_num_splits.has_value()) { @@ -858,13 +855,13 @@ void fuse_attention::apply(module_pass_manager& mpm) const // constructor value provided (for testing) - consider it enabled if > 0 if(*flash_decoding_num_splits > 0) { - num_splits = *flash_decoding_num_splits; + num_splits = *flash_decoding_num_splits; auto_calculate = false; } else { // constructor provided 0 or negative - enable with auto-calculation - num_splits = 0; + num_splits = 0; auto_calculate = true; } } @@ -872,7 +869,7 @@ void fuse_attention::apply(module_pass_manager& mpm) const { // flash decoding is explicitly enabled via environment variable flash_enabled = true; - + // check if user specified number of splits num_splits = get_num_splits(); if(num_splits > 0) @@ -880,17 +877,15 @@ void fuse_attention::apply(module_pass_manager& mpm) const else { // user didn't specify or specified 0/negative - auto-calculate - num_splits = 0; + num_splits = 0; auto_calculate = true; } } - + if(flash_enabled) { - match::find_matches(mpm, find_flash_decoding{ - .groups = num_splits, - .auto_calculate = auto_calculate - }); + match::find_matches( + mpm, find_flash_decoding{.groups = num_splits, .auto_calculate = auto_calculate}); mpm.run_pass(dead_code_elimination{}); } } diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index 188d92b40ea..9d34c0cd0f1 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -33,16 +33,16 @@ inline namespace MIGRAPHX_INLINE_NS { /** * Calculate split factor for a dimension to make it less than min_size. - * - * This function finds the largest divisor that can divide the dimension - * to make it less than min_size. It uses prime factors [2, 3, 5, 7, 11] + * + * This function finds the largest divisor that can divide the dimension + * to make it less than min_size. It uses prime factors [2, 3, 5, 7, 11] * to find good divisors that work well for parallel execution. - * + * * Used by: * - rewrite_topk: Splits large topk operations for better performance * - flash_decoding: Splits attention sequence dimension for parallelization * - split_reduce (GPU JIT): Splits reduction operations - * + * * @param r The dimension size to split (will be modified to remaining size) * @param min_size The minimum size threshold * @return The split factor (number of groups) @@ -54,9 +54,7 @@ inline std::size_t split_dim(std::size_t& r, std::size_t min_size) while(r > min_size) { // NOLINTNEXTLINE(readability-qualified-auto) - auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { - return r % d == 0; - }); + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; r /= *it; @@ -67,27 +65,26 @@ inline std::size_t split_dim(std::size_t& r, std::size_t min_size) /** * Calculate split factor with maximum splits constraint. - * + * * Similar to split_dim but also respects a maximum number of splits. * Useful when there's a limit on parallelization due to hardware constraints. - * + * * @param dimension The dimension size to split - * @param min_size Minimum size per chunk after splitting + * @param min_size Minimum size per chunk after splitting * @param max_splits Maximum number of splits allowed (0 = no limit) * @return The split factor that respects both constraints */ -inline std::size_t split_dim_with_max(std::size_t dimension, - std::size_t min_size, - std::size_t max_splits = 0) +inline std::size_t +split_dim_with_max(std::size_t dimension, std::size_t min_size, std::size_t max_splits = 0) { // Make a copy since split_dim modifies the value - std::size_t remaining = dimension; + std::size_t remaining = dimension; std::size_t num_splits = split_dim(remaining, min_size); - + // If no max constraint or already within limit, return as is if(max_splits == 0 || num_splits <= max_splits) return num_splits; - + // Reduce splits to respect max_splits constraint auto factors = make_array(2, 3, 5, 7, 11); while(num_splits > max_splits) @@ -106,7 +103,7 @@ inline std::size_t split_dim_with_max(std::size_t dimension, if(num_splits > max_splits && num_splits < 2) break; } - + return num_splits; } diff --git a/test/verify/test_attention_flash_decoding_3d.cpp b/test/verify/test_attention_flash_decoding_3d.cpp index 5471940e1f8..068b2be1688 100644 --- a/test/verify/test_attention_flash_decoding_3d.cpp +++ b/test/verify/test_attention_flash_decoding_3d.cpp @@ -51,16 +51,19 @@ struct test_attention_flash_decoding_3d : verify_programadd_parameter("v", v_shape); auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {1, 64, 256} - auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); // {1, 64, 1} - rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), - rmax); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), + gemm1); // {1, 64, 1} + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), rmax); auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // {1, 64, 256} - auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // {1, 64, 256} - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); // {1, 64, 1} - rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), - rsum); // {1, 64, 256} + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // {1, 64, 256} + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), + exp); // {1, 64, 1} + rsum = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), + rsum); // {1, 64, 256} auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // {1, 64, 256} - auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // {1, 64, 32} + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // {1, 64, 32} mm->add_return({gemm2}); return p1; } From cefa6e2a26ede25ebf60befd3c1d8787631a692a Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 01:41:49 -0600 Subject: [PATCH 09/30] fix defaults, remove comments, clean up helper func --- docs/reference/MIGraphX-dev-env-vars.rst | 2 +- src/fuse_attention.cpp | 56 +++++++++++------------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index 39828d74e50..e28abb0aee6 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -183,7 +183,7 @@ Model performance tunable variables change the compilation behavior of a model. | Default: ``0`` (automatic calculation). - | Note: This variable is only used when ``MIGRAPHX_FLASH_DECODING_ENABLED=1``. + | Note: This variable implicitly sets ``MIGRAPHX_FLASH_DECODING_ENABLED=1``. * - | ``MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE`` | Sets the minimum chunk size for automatic split calculation in flash decoding. diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 54a2b15d2b5..eea94bacf9e 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -47,7 +47,18 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); bool is_flash_decoding_enabled() { return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); } -std::size_t get_num_splits() { return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); } +// Get num_splits with priority: struct member > env var > 0 (not set) +std::size_t get_num_splits(const std::optional& member_num_splits) +{ + // struct member var is used for testing + if(member_num_splits.has_value()) + { + return *member_num_splits; + } + + // otherwise return env var value, or 0 if not set + return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); +} // calculate optimal flash decoding splits inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, @@ -475,23 +486,22 @@ struct find_flash_decoding std::size_t actual_groups = groups; if(auto_calculate || groups <= 0) { - // Auto-calculate the optimal number of splits - std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 256); - std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 32); - std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 1024); + // TODO: run experiments to find the optimal values for min_chunk and max_splits + std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); + std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 16); + std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 32); - // Check if sequence length meets threshold for flash decoding if(sequence_length < threshold) return; actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); - // Skip if auto-calculation determines no splitting needed + // skip if auto-calculation determines no splitting needed if(actual_groups <= 1) return; } - // Skip if no actual splitting (num_splits must be > 1) + // skip if no actual splitting (actual_groups must be > 1) if(actual_groups <= 1) return; @@ -848,35 +858,21 @@ void fuse_attention::apply(module_pass_manager& mpm) const std::size_t num_splits = 0; bool auto_calculate = false; - // for testing, check if the number of splits has been explicitly set - if(flash_decoding_num_splits.has_value()) + std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); + + // enable flash decoding if splits configured or explicitly enabled + if(configured_splits > 0 or is_flash_decoding_enabled()) { flash_enabled = true; - // constructor value provided (for testing) - consider it enabled if > 0 - if(*flash_decoding_num_splits > 0) + + if(configured_splits > 0) { - num_splits = *flash_decoding_num_splits; + num_splits = configured_splits; auto_calculate = false; } else { - // constructor provided 0 or negative - enable with auto-calculation - num_splits = 0; - auto_calculate = true; - } - } - else if(is_flash_decoding_enabled()) - { - // flash decoding is explicitly enabled via environment variable - flash_enabled = true; - - // check if user specified number of splits - num_splits = get_num_splits(); - if(num_splits > 0) - auto_calculate = false; - else - { - // user didn't specify or specified 0/negative - auto-calculate + // 0 means auto-calculate num_splits = 0; auto_calculate = true; } From c749ab3d0c832bd81e5c6245d41fa244b717441c Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 01:44:18 -0600 Subject: [PATCH 10/30] update docs --- docs/reference/MIGraphX-dev-env-vars.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index e28abb0aee6..ec824dc5d31 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -190,7 +190,7 @@ Model performance tunable variables change the compilation behavior of a model. - | Takes a positive integer representing the minimum size of each chunk after splitting. - | Default: ``256``. + | Default: ``32``. | Note: Only used when automatic split calculation is enabled. @@ -199,7 +199,7 @@ Model performance tunable variables change the compilation behavior of a model. - | Takes a positive integer representing the maximum number of splits. - | Default: ``32``. + | Default: ``16``. | Note: Only used when automatic split calculation is enabled. @@ -208,7 +208,7 @@ Model performance tunable variables change the compilation behavior of a model. - | Takes a positive integer. Flash decoding is only applied when the sequence length is greater than or equal to this threshold. - | Default: ``1024``. + | Default: ``32``. | Note: Only used when automatic split calculation is enabled. From 7eeacc3139c73bf4c8a7b440bb375f05efa46892 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 01:58:59 -0600 Subject: [PATCH 11/30] AIMIGRAPHX-341 handle uneven splits --- src/fuse_attention.cpp | 42 ++++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index eea94bacf9e..4f6489a6e60 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -310,9 +310,9 @@ struct find_flash_decoding size_t n = k_lens[ndim - 1]; size_t g = num_groups; - // TODO: handle uneven splits; this is caught in `apply` for now + // Note: sequence length may have been padded to be divisible by num_groups assert(n % g == 0 and - "Key-value sequence length must be divisible by number of splits/groups"); + "Key-value sequence length must be divisible by number of splits/groups (after padding)"); size_t n_split = n / g; transformed_shapes_result result; @@ -505,15 +505,15 @@ struct find_flash_decoding if(actual_groups <= 1) return; - // check if N dimension is evenly divisible by num_splits + // calculate padded sequence length if not evenly divisible + std::size_t padded_sequence_length = sequence_length; + std::size_t padding_needed = 0; if(sequence_length % actual_groups != 0) - return; - - // get Q, V, K shapes from gemms - auto qkv_shapes = get_qkv_shapes(q_param, k_param, v_param); - - // check shapes are ok and get flash decoding transformed shapes (Q', V', K') - auto transform_info = get_transformed_shapes(qkv_shapes, actual_groups); + { + // round up to nearest multiple of actual_groups + padded_sequence_length = ((sequence_length + actual_groups - 1) / actual_groups) * actual_groups; + padding_needed = padded_sequence_length - sequence_length; + } // create mapping from submodule params to main module inputs auto group_inputs = attn_group_ins->inputs(); @@ -524,6 +524,28 @@ struct find_flash_decoding auto k = map_param_to_main.at(k_param); auto v = map_param_to_main.at(v_param); + // pad K and V if necessary + if(padding_needed > 0) + { + // K shape: [B, k, N] or [B, H, k, N] for 4D. Padding on N + auto k_ndim = k->get_shape().ndim(); + std::vector k_pads(2 * k_ndim, 0); + k_pads[k_ndim + k_ndim - 1] = padding_needed; // pad right on last dim + k = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", k_pads}}), k); + + // V shape: [B, N, D] or [B, H, N, D] for 4D + auto v_ndim = v->get_shape().ndim(); + std::vector v_pads(2 * v_ndim, 0); + v_pads[v_ndim + v_ndim - 2] = padding_needed; // pad right on N dim + v = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", v_pads}}), v); + } + + // get Q, K, V shapes (using potentially padded K and V) + auto qkv_shapes = get_qkv_shapes(q, k, v); + + // check shapes are ok and get flash decoding transformed shapes (Q', V', K') + auto transform_info = get_transformed_shapes(qkv_shapes, actual_groups); + // insert reshape operations before group, for Q, K, V auto q_ndim = q->get_shape().lens().size(); int64_t g_axis = q_ndim - 2; From 2824c923b4508291a75b6337a4425f7916eab5b3 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 02:00:51 -0600 Subject: [PATCH 12/30] format --- src/fuse_attention.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 4f6489a6e60..04fa83dcd16 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -886,7 +886,7 @@ void fuse_attention::apply(module_pass_manager& mpm) const if(configured_splits > 0 or is_flash_decoding_enabled()) { flash_enabled = true; - + if(configured_splits > 0) { num_splits = configured_splits; From 23f5cbb79181f8cb45eb2076477c5e1c606c9116 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 5 Dec 2025 02:08:06 -0600 Subject: [PATCH 13/30] update doc --- docs/reference/MIGraphX-dev-env-vars.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index ec824dc5d31..5524ee45cf5 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -183,7 +183,7 @@ Model performance tunable variables change the compilation behavior of a model. | Default: ``0`` (automatic calculation). - | Note: This variable implicitly sets ``MIGRAPHX_FLASH_DECODING_ENABLED=1``. + | Note: This variable implicitly sets ``MIGRAPHX_FLASH_DECODING_ENABLED=1``. If not set, and ``MIGRAPHX_FLASH_DECODING_ENABLED=1``, the number of splits will be automatically calculated. * - | ``MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE`` | Sets the minimum chunk size for automatic split calculation in flash decoding. From ff755dd0eb9f88e62f149be35754cdcee4e1de7e Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 10:59:20 -0600 Subject: [PATCH 14/30] format --- src/fuse_attention.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 04fa83dcd16..45a1bfa85dc 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -311,8 +311,8 @@ struct find_flash_decoding size_t g = num_groups; // Note: sequence length may have been padded to be divisible by num_groups - assert(n % g == 0 and - "Key-value sequence length must be divisible by number of splits/groups (after padding)"); + assert(n % g == 0 and "Key-value sequence length must be divisible by number of " + "splits/groups (after padding)"); size_t n_split = n / g; transformed_shapes_result result; @@ -511,8 +511,9 @@ struct find_flash_decoding if(sequence_length % actual_groups != 0) { // round up to nearest multiple of actual_groups - padded_sequence_length = ((sequence_length + actual_groups - 1) / actual_groups) * actual_groups; - padding_needed = padded_sequence_length - sequence_length; + padded_sequence_length = + ((sequence_length + actual_groups - 1) / actual_groups) * actual_groups; + padding_needed = padded_sequence_length - sequence_length; } // create mapping from submodule params to main module inputs From 5d08ab028086e042477fe9ef2846a06a7bd84f5d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 11:18:56 -0600 Subject: [PATCH 15/30] add helper func for ceil; cursor made test file --- src/fuse_attention.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 45a1bfa85dc..25f73bc13fe 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -511,9 +512,8 @@ struct find_flash_decoding if(sequence_length % actual_groups != 0) { // round up to nearest multiple of actual_groups - padded_sequence_length = - ((sequence_length + actual_groups - 1) / actual_groups) * actual_groups; - padding_needed = padded_sequence_length - sequence_length; + padded_sequence_length = ceil_mul_of(sequence_length, actual_groups); + padding_needed = padded_sequence_length - sequence_length; } // create mapping from submodule params to main module inputs From fe2cb8fed7f91e07ef5ed64366ea0d39602d2930 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 18:14:42 -0600 Subject: [PATCH 16/30] update helper func and comments; cursor made tests --- src/fuse_attention.cpp | 2 +- src/include/migraphx/generic_float.hpp | 7 + src/include/migraphx/split_factor.hpp | 77 ++---- src/targets/gpu/jit/reduce.cpp | 3 +- test/math_utils_test.cpp | 347 +++++++++++++++++++++++++ 5 files changed, 377 insertions(+), 59 deletions(-) create mode 100644 test/math_utils_test.cpp diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 25f73bc13fe..02b5bd1373c 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -66,7 +66,7 @@ inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, std::size_t min_chunk_size, std::size_t max_splits) { - return split_dim_with_max(sequence_length, min_chunk_size, max_splits); + return split_dim(sequence_length, min_chunk_size, max_splits); } // TODO: Write this in matcher.hpp as a general matcher for iterating through inputs diff --git a/src/include/migraphx/generic_float.hpp b/src/include/migraphx/generic_float.hpp index b077d16631b..c70621956d4 100644 --- a/src/include/migraphx/generic_float.hpp +++ b/src/include/migraphx/generic_float.hpp @@ -42,6 +42,13 @@ constexpr std::size_t integer_divide_ceil(std::size_t x, std::size_t y) return (x + y - std::size_t{1}) / y; } +// compute the smallest multiple of y that is greater than or equal to x +// this is equivalent to y * ceil(x / y) +constexpr std::size_t ceil_mul_of(std::size_t x, std::size_t y) +{ + return y * integer_divide_ceil(x, y); +} + template struct unsigned_type { diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index 9d34c0cd0f1..b22c87b34a0 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -26,85 +26,50 @@ #include #include +#include #include - +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { /** - * Calculate split factor for a dimension to make it less than min_size. - * - * This function finds the largest divisor that can divide the dimension - * to make it less than min_size. It uses prime factors [2, 3, 5, 7, 11] - * to find good divisors that work well for parallel execution. + * Calculate split factor with or without maximum splits constraint. * * Used by: * - rewrite_topk: Splits large topk operations for better performance * - flash_decoding: Splits attention sequence dimension for parallelization * - split_reduce (GPU JIT): Splits reduction operations * - * @param r The dimension size to split (will be modified to remaining size) - * @param min_size The minimum size threshold - * @return The split factor (number of groups) + * To compute the number of split groups it finds the largest + * divisor that can divide dimension to make it less than min_size. + * + * @param dimension The dimension size to split + * @param min_size Target threshold - splits until remaining size is less than this value + * @param max_splits Target threshold - if reached, returns the smallest split factor greater than or equal to max_splits that evenly divides dimension. Optional + * @return The split factor that respects both constraints */ -inline std::size_t split_dim(std::size_t& r, std::size_t min_size) +inline std::size_t +split_dim(std::size_t dimension, std::size_t min_size, std::size_t max_splits = std::numeric_limits::max()) { std::size_t n = 1; + std::size_t r = dimension; auto factors = make_array(2, 3, 5, 7, 11); - while(r > min_size) + + while(r > min_size and n < max_splits) { // NOLINTNEXTLINE(readability-qualified-auto) - auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); + auto it = std::find_if(factors.begin(), factors.end(), + [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; + + // Note: current functionality is to return the smallest split factor greater than max_splits that evenly divides dimension; max_splits + // is not a cap; n can be >= max_splits. r /= *it; n *= *it; } - return n; -} -/** - * Calculate split factor with maximum splits constraint. - * - * Similar to split_dim but also respects a maximum number of splits. - * Useful when there's a limit on parallelization due to hardware constraints. - * - * @param dimension The dimension size to split - * @param min_size Minimum size per chunk after splitting - * @param max_splits Maximum number of splits allowed (0 = no limit) - * @return The split factor that respects both constraints - */ -inline std::size_t -split_dim_with_max(std::size_t dimension, std::size_t min_size, std::size_t max_splits = 0) -{ - // Make a copy since split_dim modifies the value - std::size_t remaining = dimension; - std::size_t num_splits = split_dim(remaining, min_size); - - // If no max constraint or already within limit, return as is - if(max_splits == 0 || num_splits <= max_splits) - return num_splits; - - // Reduce splits to respect max_splits constraint - auto factors = make_array(2, 3, 5, 7, 11); - while(num_splits > max_splits) - { - // Remove the smallest prime factor to reduce splits - for(auto factor : factors) - { - if(num_splits % factor == 0) - { - num_splits /= factor; - remaining *= factor; - break; - } - } - // Safety check to avoid infinite loop - if(num_splits > max_splits && num_splits < 2) - break; - } - - return num_splits; + return n; } } // namespace MIGRAPHX_INLINE_NS diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 4da8c1ad49c..36d14236fca 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -176,8 +176,7 @@ static std::vector split_reduce(const std::vector& inputs, assert(faxis < reduce_shape.lens().size()); auto r = input_shape.lens()[faxis]; - // Use common split_dim_with_max function from split_factor.hpp - std::size_t n = split_dim_with_max(r, min_size, max_splits); + std::size_t n = split_dim(r, min_size, max_splits); assert(n != 1); std::transform( inputs.begin(), inputs.end(), std::back_inserter(result), [&](const shape& s) -> shape { diff --git a/test/math_utils_test.cpp b/test/math_utils_test.cpp new file mode 100644 index 00000000000..9ffab6169e7 --- /dev/null +++ b/test/math_utils_test.cpp @@ -0,0 +1,347 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include "test.hpp" + +TEST_CASE(integer_divide_ceil_basic) +{ + // Test exact division + EXPECT(migraphx::integer_divide_ceil(10, 5) == 2); + EXPECT(migraphx::integer_divide_ceil(12, 4) == 3); + EXPECT(migraphx::integer_divide_ceil(15, 3) == 5); + + // Test division with remainder (should round up) + EXPECT(migraphx::integer_divide_ceil(10, 3) == 4); // 10/3 = 3.33... -> 4 + EXPECT(migraphx::integer_divide_ceil(11, 3) == 4); // 11/3 = 3.66... -> 4 + EXPECT(migraphx::integer_divide_ceil(13, 5) == 3); // 13/5 = 2.6 -> 3 + + // Test with 1 + EXPECT(migraphx::integer_divide_ceil(1, 1) == 1); + EXPECT(migraphx::integer_divide_ceil(5, 1) == 5); + EXPECT(migraphx::integer_divide_ceil(100, 1) == 100); + + // Test edge cases + EXPECT(migraphx::integer_divide_ceil(0, 5) == 0); + EXPECT(migraphx::integer_divide_ceil(1, 2) == 1); + EXPECT(migraphx::integer_divide_ceil(1, 10) == 1); +} + +TEST_CASE(integer_divide_ceil_large_numbers) +{ + // Test with larger numbers + EXPECT(migraphx::integer_divide_ceil(1000, 7) == 143); // 1000/7 = 142.857... -> 143 + EXPECT(migraphx::integer_divide_ceil(1024, 32) == 32); // Exact division + EXPECT(migraphx::integer_divide_ceil(1025, 32) == 33); // 1025/32 = 32.03... -> 33 + + // Test with powers of 2 + EXPECT(migraphx::integer_divide_ceil(127, 8) == 16); // 127/8 = 15.875 -> 16 + EXPECT(migraphx::integer_divide_ceil(128, 8) == 16); // Exact division + EXPECT(migraphx::integer_divide_ceil(129, 8) == 17); // 129/8 = 16.125 -> 17 +} + +TEST_CASE(ceil_mul_of_basic) +{ + // Test exact multiples (no rounding needed) + EXPECT(migraphx::ceil_mul_of(10, 5) == 10); // 10 is already a multiple of 5 + EXPECT(migraphx::ceil_mul_of(12, 4) == 12); // 12 is already a multiple of 4 + EXPECT(migraphx::ceil_mul_of(15, 3) == 15); // 15 is already a multiple of 3 + + // Test rounding up to next multiple + EXPECT(migraphx::ceil_mul_of(11, 5) == 15); // Next multiple of 5 after 11 is 15 + EXPECT(migraphx::ceil_mul_of(13, 4) == 16); // Next multiple of 4 after 13 is 16 + EXPECT(migraphx::ceil_mul_of(17, 3) == 18); // Next multiple of 3 after 17 is 18 + + // Test with 1 (should always return the original number) + EXPECT(migraphx::ceil_mul_of(5, 1) == 5); + EXPECT(migraphx::ceil_mul_of(100, 1) == 100); + + // Test edge cases + EXPECT(migraphx::ceil_mul_of(0, 5) == 0); + EXPECT(migraphx::ceil_mul_of(1, 5) == 5); + EXPECT(migraphx::ceil_mul_of(1, 10) == 10); +} + +TEST_CASE(ceil_mul_of_powers_of_two) +{ + // Test with powers of 2 (common in GPU programming) + EXPECT(migraphx::ceil_mul_of(100, 32) == 128); // 32 * 4 = 128 + EXPECT(migraphx::ceil_mul_of(128, 32) == 128); // Already aligned + EXPECT(migraphx::ceil_mul_of(129, 32) == 160); // 32 * 5 = 160 + + EXPECT(migraphx::ceil_mul_of(250, 64) == 256); // 64 * 4 = 256 + EXPECT(migraphx::ceil_mul_of(256, 64) == 256); // Already aligned + EXPECT(migraphx::ceil_mul_of(257, 64) == 320); // 64 * 5 = 320 + + // Warp size alignment (32 threads) + EXPECT(migraphx::ceil_mul_of(30, 32) == 32); + EXPECT(migraphx::ceil_mul_of(32, 32) == 32); + EXPECT(migraphx::ceil_mul_of(33, 32) == 64); + EXPECT(migraphx::ceil_mul_of(64, 32) == 64); + EXPECT(migraphx::ceil_mul_of(65, 32) == 96); +} + +TEST_CASE(ceil_mul_of_flash_attention_use_case) +{ + // Test the specific use case from flash attention + // Simulating the padding of sequence length to be divisible by number of groups + + // Example 1: sequence_length=100, num_groups=8 + EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 8 * 13 = 104 + + // Example 2: sequence_length=127, num_groups=16 + EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 16 * 8 = 128 + + // Example 3: sequence_length=200, num_groups=32 + EXPECT(migraphx::ceil_mul_of(200, 32) == 224); // 32 * 7 = 224 + + // Example 4: Already divisible + EXPECT(migraphx::ceil_mul_of(192, 32) == 192); // Already divisible by 32 +} + +TEST_CASE(ceil_mul_of_consistency_with_integer_divide_ceil) +{ + // Verify that ceil_mul_of(x, y) == y * integer_divide_ceil(x, y) + // This tests the implementation relationship + + std::size_t test_cases[][2] = { + {10, 3}, + {15, 4}, + {100, 7}, + {256, 32}, + {1000, 13}, + {1, 10}, + {0, 5} + }; + + for(const auto& tc : test_cases) + { + std::size_t x = tc[0]; + std::size_t y = tc[1]; + + std::size_t expected = y * migraphx::integer_divide_ceil(x, y); + std::size_t actual = migraphx::ceil_mul_of(x, y); + + EXPECT(actual == expected); + } +} + +TEST_CASE(ceil_mul_of_large_numbers) +{ + // Test with larger sequence lengths that might appear in real models + EXPECT(migraphx::ceil_mul_of(1024, 16) == 1024); // Already aligned + EXPECT(migraphx::ceil_mul_of(1025, 16) == 1040); // 16 * 65 = 1040 + EXPECT(migraphx::ceil_mul_of(2048, 64) == 2048); // Already aligned + EXPECT(migraphx::ceil_mul_of(2049, 64) == 2112); // 64 * 33 = 2112 + + // Very large numbers + EXPECT(migraphx::ceil_mul_of(10000, 128) == 10112); // 128 * 79 = 10112 + EXPECT(migraphx::ceil_mul_of(16384, 256) == 16384); // Already aligned + EXPECT(migraphx::ceil_mul_of(16385, 256) == 16640); // 256 * 65 = 16640 +} + +TEST_CASE(split_dim_basic) +{ + // Test with no max_splits constraint (using default) + + // Should split 100 into chunks > 10 + // 100 = 2^2 * 5^2; factors: 2,2,5,5 -> splits = 20, remaining = 5 + EXPECT(migraphx::split_dim(100, 10) == 20); // 100/20 = 5, stops because 5 <= 10 + + // Should split 64 into chunks > 10 + // 64 = 2^6; can use 2,2,2 -> splits = 8, remaining = 8 + EXPECT(migraphx::split_dim(64, 10) == 8); // 64/8 = 8, stops because 8 <= 10 + + // Should not split if already small enough + EXPECT(migraphx::split_dim(10, 10) == 1); // 10 is not > 10, so no split + EXPECT(migraphx::split_dim(11, 10) == 11); // 11 is a factor itself, 11/11 = 1 + + // Prime numbers that can't be factored + EXPECT(migraphx::split_dim(13, 10) == 1); // 13 is prime (not in factor list) + EXPECT(migraphx::split_dim(17, 10) == 1); // 17 is prime (not in factor list) + + // Numbers with factors in [2,3,5,7,11] + EXPECT(migraphx::split_dim(30, 5) == 6); // 30 = 2*3*5, splits to 5 + EXPECT(migraphx::split_dim(77, 10) == 77); // can be evenly split into 11 size chunks; next divisor splits to 1 size chunks +} + +TEST_CASE(split_dim_with_max_splits) +{ + // Test with explicit max_splits constraint + // Note: max_splits is NOT a hard cap - function returns smallest split factor > max_splits that evenly divides dimension + + // When split factor would exceed max_splits, returns next valid divisor + EXPECT(migraphx::split_dim(100, 10, 4) == 4); // 100 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 + EXPECT(migraphx::split_dim(100, 10, 2) == 2); + + // Max splits doesn't force splitting if min_size constraint would be violated + EXPECT(migraphx::split_dim(20, 10, 4) == 2); // Can only split to 2 (20/2=10, not > 10) + EXPECT(migraphx::split_dim(15, 10, 4) == 3); // 15 = 3*5, splits to 3, remaining = 5 + + // Test with powers of 2 + EXPECT(migraphx::split_dim(128, 10, 8) == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 + EXPECT(migraphx::split_dim(128, 10, 4) == 4); // 128 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 + EXPECT(migraphx::split_dim(128, 20, 8) == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 +} + +TEST_CASE(split_dim_edge_cases) +{ + // Test edge cases + + // Very small dimensions + EXPECT(migraphx::split_dim(1, 0) == 1); // 1 can't be split + EXPECT(migraphx::split_dim(2, 0) == 2); // 2/2 = 1 > 0 + EXPECT(migraphx::split_dim(2, 1) == 2); // 2/2 = 1, but we continue while r > min_size, so 2 > 1 allows split + + // Exact boundary conditions + EXPECT(migraphx::split_dim(20, 9) == 4); // 20 = 2^2 * 5, factors 2,2 before 20/4=5 <= 9 + EXPECT(migraphx::split_dim(20, 10) == 2); // 20 = 2^2 * 5, factors 2 before 20/2=10 <= 10 + EXPECT(migraphx::split_dim(21, 10) == 3); // 21 = 3*7, splits by 3 first, 21/3 = 7 <= 10 + + // Large prime numbers + EXPECT(migraphx::split_dim(97, 10) == 1); // 97 is prime + EXPECT(migraphx::split_dim(101, 10) == 1); // 101 is prime +} + +TEST_CASE(split_dim_factorization_order) +{ + // Test that factorization happens in the expected order [2,3,5,7,11] + + // 60 = 2^2 * 3 * 5 + // With min_size=10: 60->30->15->5 (stops because 5 <= 10) + // Factors used: 2, 2, 3 (product = 12) + EXPECT(migraphx::split_dim(60, 10) == 12); // 60/12 = 5 + + // 210 = 2 * 3 * 5 * 7 + // With min_size=20: continues factoring while 210 > 20 + // Factors all: 2*3*5*7 = 210, but stops at 2*3*5 = 30 since 210/30 = 7 <= 20 + EXPECT(migraphx::split_dim(210, 20) == 30); // 210/30 = 7 + + // 462 = 2 * 3 * 7 * 11 + // With min_size=30: 462->231->77->11 (stops because 11 <= 30) + EXPECT(migraphx::split_dim(462, 30) == 42); // 462/42 = 11 <= 30 +} + +TEST_CASE(split_dim_reduce_use_case) +{ + // Test the specific use case from reduce.cpp + // These are realistic values that might appear in reduction operations + + // Large reduction dimension with typical min_size and max_splits + EXPECT(migraphx::split_dim(1024, 64, 16) == 16); // 1024 = 2^10; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops + EXPECT(migraphx::split_dim(1000, 64, 16) == 40); // 1000 = 2^3 * 5^3; n=8, r=125 > 64, so next would be n=40 > 16 + + // Smaller dimensions + EXPECT(migraphx::split_dim(256, 32, 8) == 8); // 256 = 2^8; would split to 16 (256/16=16), exceeds max_splits=8 + EXPECT(migraphx::split_dim(200, 32, 8) == 8); // 200 = 2^3 * 5^2; would split to 8 (200/8=25), exceeds max_splits=8 +} + +TEST_CASE(split_dim_flash_attention_use_case) +{ + // Test use cases from flash attention decoding + // Sequence lengths need to be split for parallel processing + + // Typical sequence lengths in attention + EXPECT(migraphx::split_dim(2048, 128, 16) == 16); // 2048 = 2^11; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops + EXPECT(migraphx::split_dim(4096, 256, 16) == 16); // 4096 = 2^12; splits to 16 (4096/16=256) + + // Non-aligned sequence lengths + // 1536 = 2^9 * 3; would continue past max_splits=16 to get 1536/24=64 < 128 + EXPECT(migraphx::split_dim(1536, 128, 16) == 16); // 1536 = 2^9 * 3; n=16, r=128, stops (r not > 128) + // 3000 = 2^3 * 3 * 5^3 + EXPECT(migraphx::split_dim(3000, 200, 16) == 24); // 3*5 = 24, 3000/24 = 125 + + // Smaller sequences + EXPECT(migraphx::split_dim(512, 64, 8) == 8); // 512 = 2^9; stops at n=8 (512/8=64) + EXPECT(migraphx::split_dim(768, 64, 8) == 8); // 768 = 2^8 * 3; n=8, r=128 > 64, stops (r not > 128) +} + +TEST_CASE(split_dim_no_limit_vs_explicit_max) +{ + // Verify that default (no limit) and explicit max give same results when max is large + std::size_t large_max = std::numeric_limits::max(); + + // These should produce identical results + EXPECT(migraphx::split_dim(1000, 10) == migraphx::split_dim(1000, 10, large_max)); + EXPECT(migraphx::split_dim(512, 32) == migraphx::split_dim(512, 32, large_max)); + EXPECT(migraphx::split_dim(360, 20) == migraphx::split_dim(360, 20, large_max)); + + // Test that max_splits affects the result (but doesn't necessarily limit it) + // max_splits acts as a threshold - result may exceed it + EXPECT(migraphx::split_dim(1000, 10) != migraphx::split_dim(1000, 10, 8)); + EXPECT(migraphx::split_dim(512, 8) != migraphx::split_dim(512, 8, 16)); +} + +TEST_CASE(split_dim_consistency_check) +{ + // Verify important properties of split_dim + + // Property 1: The algorithm stops when remaining <= min_size + // This means the final remaining might be <= min_size + for(std::size_t dim : {100, 256, 512, 1000, 2048}) + { + for(std::size_t min_size : {8, 16, 32, 64}) + { + std::size_t splits = migraphx::split_dim(dim, min_size); + if(splits > 1) + { + // The algorithm continues while r > min_size, + // so it stops when r <= min_size + // We can't guarantee remaining > min_size, + // but we know the split was valid + EXPECT(splits >= 1); + EXPECT(dim / splits > 0); + } + } + } + + // Property 2: With max_splits, result is smallest valid divisor >= max_splits + // Note: max_splits is NOT a hard cap - it's a threshold like min_size + for(std::size_t dim : {100, 256, 512, 1000}) + { + for(std::size_t max_splits : {4, 8, 16}) + { + std::size_t splits = migraphx::split_dim(dim, 10, max_splits); + // Result should evenly divide dimension + EXPECT(dim % splits == 0); + // Result should make remaining size < min_size (10) + EXPECT(dim / splits < 10 || splits >= max_splits); + } + } + + // Property 3: Increasing min_size decreases or maintains splits + for(std::size_t dim : {256, 512, 1024}) + { + std::size_t splits_8 = migraphx::split_dim(dim, 8); + std::size_t splits_16 = migraphx::split_dim(dim, 16); + std::size_t splits_32 = migraphx::split_dim(dim, 32); + + EXPECT(splits_8 >= splits_16); + EXPECT(splits_16 >= splits_32); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 8e0d4546af5118b296445f336c1a9ac3fd35967d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 22:46:13 -0600 Subject: [PATCH 17/30] make split_dim by ref --- src/fuse_attention.cpp | 24 +-- src/include/migraphx/split_factor.hpp | 16 +- src/rewrite_topk.cpp | 1 + src/targets/gpu/jit/reduce.cpp | 2 +- test/math_utils_test.cpp | 209 ++++++++++++++++++++------ 5 files changed, 177 insertions(+), 75 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 02b5bd1373c..37050478aa9 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -66,7 +66,8 @@ inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, std::size_t min_chunk_size, std::size_t max_splits) { - return split_dim(sequence_length, min_chunk_size, max_splits); + std::size_t r = sequence_length; + return split_dim(r, min_chunk_size, max_splits); } // TODO: Write this in matcher.hpp as a general matcher for iterating through inputs @@ -244,9 +245,8 @@ struct find_attention struct find_flash_decoding { - // number of groups (0 or negative means auto-calculate) + // number of groups (0 means auto-calculate) std::size_t groups; - bool auto_calculate = false; auto matcher() const { @@ -485,7 +485,7 @@ struct find_flash_decoding // Determine actual number of splits to use std::size_t actual_groups = groups; - if(auto_calculate || groups <= 0) + if(groups == 0) { // TODO: run experiments to find the optimal values for min_chunk and max_splits std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); @@ -879,7 +879,6 @@ void fuse_attention::apply(module_pass_manager& mpm) const // Apply flash decoding if enabled bool flash_enabled = false; std::size_t num_splits = 0; - bool auto_calculate = false; std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); @@ -888,23 +887,14 @@ void fuse_attention::apply(module_pass_manager& mpm) const { flash_enabled = true; - if(configured_splits > 0) - { - num_splits = configured_splits; - auto_calculate = false; - } - else - { - // 0 means auto-calculate - num_splits = 0; - auto_calculate = true; - } + // 0 means auto-calculate + num_splits = configured_splits > 0 ? configured_splits : 0; } if(flash_enabled) { match::find_matches( - mpm, find_flash_decoding{.groups = num_splits, .auto_calculate = auto_calculate}); + mpm, find_flash_decoding{.groups = num_splits}); mpm.run_pass(dead_code_elimination{}); } } diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index b22c87b34a0..ab4a4273d85 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -43,32 +43,26 @@ inline namespace MIGRAPHX_INLINE_NS { * To compute the number of split groups it finds the largest * divisor that can divide dimension to make it less than min_size. * - * @param dimension The dimension size to split + * @param r The value to split. This is passed by reference and will be modified to the remaining value after splitting. * @param min_size Target threshold - splits until remaining size is less than this value * @param max_splits Target threshold - if reached, returns the smallest split factor greater than or equal to max_splits that evenly divides dimension. Optional * @return The split factor that respects both constraints */ -inline std::size_t -split_dim(std::size_t dimension, std::size_t min_size, std::size_t max_splits = std::numeric_limits::max()) + +inline std::size_t split_dim(std::size_t& r, std::size_t min_size, std::size_t max_splits = std::numeric_limits::max()) { std::size_t n = 1; - std::size_t r = dimension; auto factors = make_array(2, 3, 5, 7, 11); - while(r > min_size and n < max_splits) { // NOLINTNEXTLINE(readability-qualified-auto) - auto it = std::find_if(factors.begin(), factors.end(), - [&](auto d) { return r % d == 0; }); + auto it = + std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; - - // Note: current functionality is to return the smallest split factor greater than max_splits that evenly divides dimension; max_splits - // is not a cap; n can be >= max_splits. r /= *it; n *= *it; } - return n; } diff --git a/src/rewrite_topk.cpp b/src/rewrite_topk.cpp index e94f37afdc7..d3075a3bf6b 100644 --- a/src/rewrite_topk.cpp +++ b/src/rewrite_topk.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include namespace migraphx { diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index 36d14236fca..fa417c66480 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -175,7 +175,7 @@ static std::vector split_reduce(const std::vector& inputs, assert(faxis < reduce_shape.lens().size()); - auto r = input_shape.lens()[faxis]; + std::size_t r = input_shape.lens()[faxis]; std::size_t n = split_dim(r, min_size, max_splits); assert(n != 1); std::transform( diff --git a/test/math_utils_test.cpp b/test/math_utils_test.cpp index 9ffab6169e7..7a7de9fa1ee 100644 --- a/test/math_utils_test.cpp +++ b/test/math_utils_test.cpp @@ -169,23 +169,42 @@ TEST_CASE(split_dim_basic) // Should split 100 into chunks > 10 // 100 = 2^2 * 5^2; factors: 2,2,5,5 -> splits = 20, remaining = 5 - EXPECT(migraphx::split_dim(100, 10) == 20); // 100/20 = 5, stops because 5 <= 10 + std::size_t dim100 = 100; + std::size_t result100 = migraphx::split_dim(dim100, 10); + EXPECT(result100 == 20); // 100/20 = 5, stops because 5 <= 10 // Should split 64 into chunks > 10 // 64 = 2^6; can use 2,2,2 -> splits = 8, remaining = 8 - EXPECT(migraphx::split_dim(64, 10) == 8); // 64/8 = 8, stops because 8 <= 10 + std::size_t dim64 = 64; + std::size_t result64 = migraphx::split_dim(dim64, 10); + EXPECT(result64 == 8); // 64/8 = 8, stops because 8 <= 10 // Should not split if already small enough - EXPECT(migraphx::split_dim(10, 10) == 1); // 10 is not > 10, so no split - EXPECT(migraphx::split_dim(11, 10) == 11); // 11 is a factor itself, 11/11 = 1 + std::size_t dim10 = 10; + std::size_t result10 = migraphx::split_dim(dim10, 10); + EXPECT(result10 == 1); // 10 is not > 10, so no split + + std::size_t dim11 = 11; + std::size_t result11 = migraphx::split_dim(dim11, 10); + EXPECT(result11 == 11); // 11 is a factor itself, 11/11 = 1 // Prime numbers that can't be factored - EXPECT(migraphx::split_dim(13, 10) == 1); // 13 is prime (not in factor list) - EXPECT(migraphx::split_dim(17, 10) == 1); // 17 is prime (not in factor list) + std::size_t dim13 = 13; + std::size_t result13 = migraphx::split_dim(dim13, 10); + EXPECT(result13 == 1); // 13 is prime (not in factor list) + + std::size_t dim17 = 17; + std::size_t result17 = migraphx::split_dim(dim17, 10); + EXPECT(result17 == 1); // 17 is prime (not in factor list) // Numbers with factors in [2,3,5,7,11] - EXPECT(migraphx::split_dim(30, 5) == 6); // 30 = 2*3*5, splits to 5 - EXPECT(migraphx::split_dim(77, 10) == 77); // can be evenly split into 11 size chunks; next divisor splits to 1 size chunks + std::size_t dim30 = 30; + std::size_t result30 = migraphx::split_dim(dim30, 5); + EXPECT(result30 == 6); // 30 = 2*3*5, splits to 5 + + std::size_t dim77 = 77; + std::size_t result77 = migraphx::split_dim(dim77, 10); + EXPECT(result77 == 77); // can be evenly split into 11 size chunks; next divisor splits to 1 size chunks } TEST_CASE(split_dim_with_max_splits) @@ -194,17 +213,35 @@ TEST_CASE(split_dim_with_max_splits) // Note: max_splits is NOT a hard cap - function returns smallest split factor > max_splits that evenly divides dimension // When split factor would exceed max_splits, returns next valid divisor - EXPECT(migraphx::split_dim(100, 10, 4) == 4); // 100 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 - EXPECT(migraphx::split_dim(100, 10, 2) == 2); + std::size_t dim100a = 100; + std::size_t result100a = migraphx::split_dim(dim100a, 10, 4); + EXPECT(result100a == 4); // 100 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 + + std::size_t dim100b = 100; + std::size_t result100b = migraphx::split_dim(dim100b, 10, 2); + EXPECT(result100b == 2); // Max splits doesn't force splitting if min_size constraint would be violated - EXPECT(migraphx::split_dim(20, 10, 4) == 2); // Can only split to 2 (20/2=10, not > 10) - EXPECT(migraphx::split_dim(15, 10, 4) == 3); // 15 = 3*5, splits to 3, remaining = 5 + std::size_t dim20 = 20; + std::size_t result20 = migraphx::split_dim(dim20, 10, 4); + EXPECT(result20 == 2); // Can only split to 2 (20/2=10, not > 10) + + std::size_t dim15 = 15; + std::size_t result15 = migraphx::split_dim(dim15, 10, 4); + EXPECT(result15 == 3); // 15 = 3*5, splits to 3, remaining = 5 // Test with powers of 2 - EXPECT(migraphx::split_dim(128, 10, 8) == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 - EXPECT(migraphx::split_dim(128, 10, 4) == 4); // 128 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 - EXPECT(migraphx::split_dim(128, 20, 8) == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 + std::size_t dim128a = 128; + std::size_t result128a = migraphx::split_dim(dim128a, 10, 8); + EXPECT(result128a == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 + + std::size_t dim128b = 128; + std::size_t result128b = migraphx::split_dim(dim128b, 10, 4); + EXPECT(result128b == 4); // 128 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 + + std::size_t dim128c = 128; + std::size_t result128c = migraphx::split_dim(dim128c, 20, 8); + EXPECT(result128c == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 } TEST_CASE(split_dim_edge_cases) @@ -212,18 +249,39 @@ TEST_CASE(split_dim_edge_cases) // Test edge cases // Very small dimensions - EXPECT(migraphx::split_dim(1, 0) == 1); // 1 can't be split - EXPECT(migraphx::split_dim(2, 0) == 2); // 2/2 = 1 > 0 - EXPECT(migraphx::split_dim(2, 1) == 2); // 2/2 = 1, but we continue while r > min_size, so 2 > 1 allows split + std::size_t dim1 = 1; + std::size_t result1 = migraphx::split_dim(dim1, 0); + EXPECT(result1 == 1); // 1 can't be split + + std::size_t dim2a = 2; + std::size_t result2a = migraphx::split_dim(dim2a, 0); + EXPECT(result2a == 2); // 2/2 = 1 > 0 + + std::size_t dim2b = 2; + std::size_t result2b = migraphx::split_dim(dim2b, 1); + EXPECT(result2b == 2); // 2/2 = 1, but we continue while r > min_size, so 2 > 1 allows split // Exact boundary conditions - EXPECT(migraphx::split_dim(20, 9) == 4); // 20 = 2^2 * 5, factors 2,2 before 20/4=5 <= 9 - EXPECT(migraphx::split_dim(20, 10) == 2); // 20 = 2^2 * 5, factors 2 before 20/2=10 <= 10 - EXPECT(migraphx::split_dim(21, 10) == 3); // 21 = 3*7, splits by 3 first, 21/3 = 7 <= 10 + std::size_t dim20a = 20; + std::size_t result20a = migraphx::split_dim(dim20a, 9); + EXPECT(result20a == 4); // 20 = 2^2 * 5, factors 2,2 before 20/4=5 <= 9 + + std::size_t dim20b = 20; + std::size_t result20b = migraphx::split_dim(dim20b, 10); + EXPECT(result20b == 2); // 20 = 2^2 * 5, factors 2 before 20/2=10 <= 10 + + std::size_t dim21 = 21; + std::size_t result21 = migraphx::split_dim(dim21, 10); + EXPECT(result21 == 3); // 21 = 3*7, splits by 3 first, 21/3 = 7 <= 10 // Large prime numbers - EXPECT(migraphx::split_dim(97, 10) == 1); // 97 is prime - EXPECT(migraphx::split_dim(101, 10) == 1); // 101 is prime + std::size_t dim97 = 97; + std::size_t result97 = migraphx::split_dim(dim97, 10); + EXPECT(result97 == 1); // 97 is prime + + std::size_t dim101 = 101; + std::size_t result101 = migraphx::split_dim(dim101, 10); + EXPECT(result101 == 1); // 101 is prime } TEST_CASE(split_dim_factorization_order) @@ -233,16 +291,22 @@ TEST_CASE(split_dim_factorization_order) // 60 = 2^2 * 3 * 5 // With min_size=10: 60->30->15->5 (stops because 5 <= 10) // Factors used: 2, 2, 3 (product = 12) - EXPECT(migraphx::split_dim(60, 10) == 12); // 60/12 = 5 + std::size_t dim60 = 60; + std::size_t result60 = migraphx::split_dim(dim60, 10); + EXPECT(result60 == 12); // 60/12 = 5 // 210 = 2 * 3 * 5 * 7 // With min_size=20: continues factoring while 210 > 20 // Factors all: 2*3*5*7 = 210, but stops at 2*3*5 = 30 since 210/30 = 7 <= 20 - EXPECT(migraphx::split_dim(210, 20) == 30); // 210/30 = 7 + std::size_t dim210 = 210; + std::size_t result210 = migraphx::split_dim(dim210, 20); + EXPECT(result210 == 30); // 210/30 = 7 // 462 = 2 * 3 * 7 * 11 // With min_size=30: 462->231->77->11 (stops because 11 <= 30) - EXPECT(migraphx::split_dim(462, 30) == 42); // 462/42 = 11 <= 30 + std::size_t dim462 = 462; + std::size_t result462 = migraphx::split_dim(dim462, 30); + EXPECT(result462 == 42); // 462/42 = 11 <= 30 } TEST_CASE(split_dim_reduce_use_case) @@ -251,12 +315,22 @@ TEST_CASE(split_dim_reduce_use_case) // These are realistic values that might appear in reduction operations // Large reduction dimension with typical min_size and max_splits - EXPECT(migraphx::split_dim(1024, 64, 16) == 16); // 1024 = 2^10; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops - EXPECT(migraphx::split_dim(1000, 64, 16) == 40); // 1000 = 2^3 * 5^3; n=8, r=125 > 64, so next would be n=40 > 16 + std::size_t dim1024 = 1024; + std::size_t result1024 = migraphx::split_dim(dim1024, 64, 16); + EXPECT(result1024 == 16); // 1024 = 2^10; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops + + std::size_t dim1000 = 1000; + std::size_t result1000 = migraphx::split_dim(dim1000, 64, 16); + EXPECT(result1000 == 40); // 1000 = 2^3 * 5^3; n=8, r=125 > 64, so next would be n=40 > 16 // Smaller dimensions - EXPECT(migraphx::split_dim(256, 32, 8) == 8); // 256 = 2^8; would split to 16 (256/16=16), exceeds max_splits=8 - EXPECT(migraphx::split_dim(200, 32, 8) == 8); // 200 = 2^3 * 5^2; would split to 8 (200/8=25), exceeds max_splits=8 + std::size_t dim256 = 256; + std::size_t result256 = migraphx::split_dim(dim256, 32, 8); + EXPECT(result256 == 8); // 256 = 2^8; would split to 16 (256/16=16), exceeds max_splits=8 + + std::size_t dim200 = 200; + std::size_t result200 = migraphx::split_dim(dim200, 32, 8); + EXPECT(result200 == 8); // 200 = 2^3 * 5^2; would split to 8 (200/8=25), exceeds max_splits=8 } TEST_CASE(split_dim_flash_attention_use_case) @@ -265,18 +339,33 @@ TEST_CASE(split_dim_flash_attention_use_case) // Sequence lengths need to be split for parallel processing // Typical sequence lengths in attention - EXPECT(migraphx::split_dim(2048, 128, 16) == 16); // 2048 = 2^11; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops - EXPECT(migraphx::split_dim(4096, 256, 16) == 16); // 4096 = 2^12; splits to 16 (4096/16=256) + std::size_t dim2048 = 2048; + std::size_t result2048 = migraphx::split_dim(dim2048, 128, 16); + EXPECT(result2048 == 16); // 2048 = 2^11; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops + + std::size_t dim4096 = 4096; + std::size_t result4096 = migraphx::split_dim(dim4096, 256, 16); + EXPECT(result4096 == 16); // 4096 = 2^12; splits to 16 (4096/16=256) // Non-aligned sequence lengths // 1536 = 2^9 * 3; would continue past max_splits=16 to get 1536/24=64 < 128 - EXPECT(migraphx::split_dim(1536, 128, 16) == 16); // 1536 = 2^9 * 3; n=16, r=128, stops (r not > 128) + std::size_t dim1536 = 1536; + std::size_t result1536 = migraphx::split_dim(dim1536, 128, 16); + EXPECT(result1536 == 16); // 1536 = 2^9 * 3; n=16, r=128, stops (r not > 128) + // 3000 = 2^3 * 3 * 5^3 - EXPECT(migraphx::split_dim(3000, 200, 16) == 24); // 3*5 = 24, 3000/24 = 125 + std::size_t dim3000 = 3000; + std::size_t result3000 = migraphx::split_dim(dim3000, 200, 16); + EXPECT(result3000 == 24); // 3*5 = 24, 3000/24 = 125 // Smaller sequences - EXPECT(migraphx::split_dim(512, 64, 8) == 8); // 512 = 2^9; stops at n=8 (512/8=64) - EXPECT(migraphx::split_dim(768, 64, 8) == 8); // 768 = 2^8 * 3; n=8, r=128 > 64, stops (r not > 128) + std::size_t dim512 = 512; + std::size_t result512 = migraphx::split_dim(dim512, 64, 8); + EXPECT(result512 == 8); // 512 = 2^9; stops at n=8 (512/8=64) + + std::size_t dim768 = 768; + std::size_t result768 = migraphx::split_dim(dim768, 64, 8); + EXPECT(result768 == 8); // 768 = 2^8 * 3; n=8, r=128 > 64, stops (r not > 128) } TEST_CASE(split_dim_no_limit_vs_explicit_max) @@ -285,14 +374,37 @@ TEST_CASE(split_dim_no_limit_vs_explicit_max) std::size_t large_max = std::numeric_limits::max(); // These should produce identical results - EXPECT(migraphx::split_dim(1000, 10) == migraphx::split_dim(1000, 10, large_max)); - EXPECT(migraphx::split_dim(512, 32) == migraphx::split_dim(512, 32, large_max)); - EXPECT(migraphx::split_dim(360, 20) == migraphx::split_dim(360, 20, large_max)); + std::size_t dim1000a = 1000; + std::size_t result1000a = migraphx::split_dim(dim1000a, 10); + std::size_t dim1000b = 1000; + std::size_t result1000b = migraphx::split_dim(dim1000b, 10, large_max); + EXPECT(result1000a == result1000b); + + std::size_t dim512a = 512; + std::size_t result512a = migraphx::split_dim(dim512a, 32); + std::size_t dim512b = 512; + std::size_t result512b = migraphx::split_dim(dim512b, 32, large_max); + EXPECT(result512a == result512b); + + std::size_t dim360a = 360; + std::size_t result360a = migraphx::split_dim(dim360a, 20); + std::size_t dim360b = 360; + std::size_t result360b = migraphx::split_dim(dim360b, 20, large_max); + EXPECT(result360a == result360b); // Test that max_splits affects the result (but doesn't necessarily limit it) // max_splits acts as a threshold - result may exceed it - EXPECT(migraphx::split_dim(1000, 10) != migraphx::split_dim(1000, 10, 8)); - EXPECT(migraphx::split_dim(512, 8) != migraphx::split_dim(512, 8, 16)); + std::size_t dim1000c = 1000; + std::size_t result1000c = migraphx::split_dim(dim1000c, 10); + std::size_t dim1000d = 1000; + std::size_t result1000d = migraphx::split_dim(dim1000d, 10, 8); + EXPECT(result1000c != result1000d); + + std::size_t dim512c = 512; + std::size_t result512c = migraphx::split_dim(dim512c, 8); + std::size_t dim512d = 512; + std::size_t result512d = migraphx::split_dim(dim512d, 8, 16); + EXPECT(result512c != result512d); } TEST_CASE(split_dim_consistency_check) @@ -305,7 +417,8 @@ TEST_CASE(split_dim_consistency_check) { for(std::size_t min_size : {8, 16, 32, 64}) { - std::size_t splits = migraphx::split_dim(dim, min_size); + std::size_t dim_copy = dim; + std::size_t splits = migraphx::split_dim(dim_copy, min_size); if(splits > 1) { // The algorithm continues while r > min_size, @@ -324,7 +437,8 @@ TEST_CASE(split_dim_consistency_check) { for(std::size_t max_splits : {4, 8, 16}) { - std::size_t splits = migraphx::split_dim(dim, 10, max_splits); + std::size_t dim_copy = dim; + std::size_t splits = migraphx::split_dim(dim_copy, 10, max_splits); // Result should evenly divide dimension EXPECT(dim % splits == 0); // Result should make remaining size < min_size (10) @@ -335,9 +449,12 @@ TEST_CASE(split_dim_consistency_check) // Property 3: Increasing min_size decreases or maintains splits for(std::size_t dim : {256, 512, 1024}) { - std::size_t splits_8 = migraphx::split_dim(dim, 8); - std::size_t splits_16 = migraphx::split_dim(dim, 16); - std::size_t splits_32 = migraphx::split_dim(dim, 32); + std::size_t dim_copy1 = dim; + std::size_t splits_8 = migraphx::split_dim(dim_copy1, 8); + std::size_t dim_copy2 = dim; + std::size_t splits_16 = migraphx::split_dim(dim_copy2, 16); + std::size_t dim_copy3 = dim; + std::size_t splits_32 = migraphx::split_dim(dim_copy3, 32); EXPECT(splits_8 >= splits_16); EXPECT(splits_16 >= splits_32); From 8a6dd3b0ec275be7b79858f8d79d38db7b4c0825 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 23:01:34 -0600 Subject: [PATCH 18/30] format --- src/fuse_attention.cpp | 8 +- src/include/migraphx/split_factor.hpp | 10 +- test/math_utils_test.cpp | 257 +++++++++++++------------- 3 files changed, 139 insertions(+), 136 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 37050478aa9..ea0f7aa76db 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -506,14 +506,12 @@ struct find_flash_decoding if(actual_groups <= 1) return; - // calculate padded sequence length if not evenly divisible - std::size_t padded_sequence_length = sequence_length; - std::size_t padding_needed = 0; + // calculate padding if sequence length not evenly divisible + std::size_t padding_needed = 0; if(sequence_length % actual_groups != 0) { // round up to nearest multiple of actual_groups - padded_sequence_length = ceil_mul_of(sequence_length, actual_groups); - padding_needed = padded_sequence_length - sequence_length; + padding_needed = ceil_mul_of(sequence_length, actual_groups) - sequence_length; } // create mapping from submodule params to main module inputs diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index ab4a4273d85..2571425b482 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -45,19 +45,21 @@ inline namespace MIGRAPHX_INLINE_NS { * * @param r The value to split. This is passed by reference and will be modified to the remaining value after splitting. * @param min_size Target threshold - splits until remaining size is less than this value - * @param max_splits Target threshold - if reached, returns the smallest split factor greater than or equal to max_splits that evenly divides dimension. Optional + * @param max_splits Target threshold - if reached, returns the smallest split factor greater than + * or equal to max_splits that evenly divides dimension. Optional * @return The split factor that respects both constraints */ -inline std::size_t split_dim(std::size_t& r, std::size_t min_size, std::size_t max_splits = std::numeric_limits::max()) +inline std::size_t split_dim(std::size_t& r, + std::size_t min_size, + std::size_t max_splits = std::numeric_limits::max()) { std::size_t n = 1; auto factors = make_array(2, 3, 5, 7, 11); while(r > min_size and n < max_splits) { // NOLINTNEXTLINE(readability-qualified-auto) - auto it = - std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); if(it == factors.end()) break; r /= *it; diff --git a/test/math_utils_test.cpp b/test/math_utils_test.cpp index 7a7de9fa1ee..3750213e170 100644 --- a/test/math_utils_test.cpp +++ b/test/math_utils_test.cpp @@ -35,9 +35,9 @@ TEST_CASE(integer_divide_ceil_basic) EXPECT(migraphx::integer_divide_ceil(15, 3) == 5); // Test division with remainder (should round up) - EXPECT(migraphx::integer_divide_ceil(10, 3) == 4); // 10/3 = 3.33... -> 4 - EXPECT(migraphx::integer_divide_ceil(11, 3) == 4); // 11/3 = 3.66... -> 4 - EXPECT(migraphx::integer_divide_ceil(13, 5) == 3); // 13/5 = 2.6 -> 3 + EXPECT(migraphx::integer_divide_ceil(10, 3) == 4); // 10/3 = 3.33... -> 4 + EXPECT(migraphx::integer_divide_ceil(11, 3) == 4); // 11/3 = 3.66... -> 4 + EXPECT(migraphx::integer_divide_ceil(13, 5) == 3); // 13/5 = 2.6 -> 3 // Test with 1 EXPECT(migraphx::integer_divide_ceil(1, 1) == 1); @@ -53,27 +53,27 @@ TEST_CASE(integer_divide_ceil_basic) TEST_CASE(integer_divide_ceil_large_numbers) { // Test with larger numbers - EXPECT(migraphx::integer_divide_ceil(1000, 7) == 143); // 1000/7 = 142.857... -> 143 - EXPECT(migraphx::integer_divide_ceil(1024, 32) == 32); // Exact division - EXPECT(migraphx::integer_divide_ceil(1025, 32) == 33); // 1025/32 = 32.03... -> 33 + EXPECT(migraphx::integer_divide_ceil(1000, 7) == 143); // 1000/7 = 142.857... -> 143 + EXPECT(migraphx::integer_divide_ceil(1024, 32) == 32); // Exact division + EXPECT(migraphx::integer_divide_ceil(1025, 32) == 33); // 1025/32 = 32.03... -> 33 // Test with powers of 2 - EXPECT(migraphx::integer_divide_ceil(127, 8) == 16); // 127/8 = 15.875 -> 16 - EXPECT(migraphx::integer_divide_ceil(128, 8) == 16); // Exact division - EXPECT(migraphx::integer_divide_ceil(129, 8) == 17); // 129/8 = 16.125 -> 17 + EXPECT(migraphx::integer_divide_ceil(127, 8) == 16); // 127/8 = 15.875 -> 16 + EXPECT(migraphx::integer_divide_ceil(128, 8) == 16); // Exact division + EXPECT(migraphx::integer_divide_ceil(129, 8) == 17); // 129/8 = 16.125 -> 17 } TEST_CASE(ceil_mul_of_basic) { // Test exact multiples (no rounding needed) - EXPECT(migraphx::ceil_mul_of(10, 5) == 10); // 10 is already a multiple of 5 - EXPECT(migraphx::ceil_mul_of(12, 4) == 12); // 12 is already a multiple of 4 - EXPECT(migraphx::ceil_mul_of(15, 3) == 15); // 15 is already a multiple of 3 + EXPECT(migraphx::ceil_mul_of(10, 5) == 10); // 10 is already a multiple of 5 + EXPECT(migraphx::ceil_mul_of(12, 4) == 12); // 12 is already a multiple of 4 + EXPECT(migraphx::ceil_mul_of(15, 3) == 15); // 15 is already a multiple of 3 // Test rounding up to next multiple - EXPECT(migraphx::ceil_mul_of(11, 5) == 15); // Next multiple of 5 after 11 is 15 - EXPECT(migraphx::ceil_mul_of(13, 4) == 16); // Next multiple of 4 after 13 is 16 - EXPECT(migraphx::ceil_mul_of(17, 3) == 18); // Next multiple of 3 after 17 is 18 + EXPECT(migraphx::ceil_mul_of(11, 5) == 15); // Next multiple of 5 after 11 is 15 + EXPECT(migraphx::ceil_mul_of(13, 4) == 16); // Next multiple of 4 after 13 is 16 + EXPECT(migraphx::ceil_mul_of(17, 3) == 18); // Next multiple of 3 after 17 is 18 // Test with 1 (should always return the original number) EXPECT(migraphx::ceil_mul_of(5, 1) == 5); @@ -88,13 +88,13 @@ TEST_CASE(ceil_mul_of_basic) TEST_CASE(ceil_mul_of_powers_of_two) { // Test with powers of 2 (common in GPU programming) - EXPECT(migraphx::ceil_mul_of(100, 32) == 128); // 32 * 4 = 128 - EXPECT(migraphx::ceil_mul_of(128, 32) == 128); // Already aligned - EXPECT(migraphx::ceil_mul_of(129, 32) == 160); // 32 * 5 = 160 + EXPECT(migraphx::ceil_mul_of(100, 32) == 128); // 32 * 4 = 128 + EXPECT(migraphx::ceil_mul_of(128, 32) == 128); // Already aligned + EXPECT(migraphx::ceil_mul_of(129, 32) == 160); // 32 * 5 = 160 - EXPECT(migraphx::ceil_mul_of(250, 64) == 256); // 64 * 4 = 256 - EXPECT(migraphx::ceil_mul_of(256, 64) == 256); // Already aligned - EXPECT(migraphx::ceil_mul_of(257, 64) == 320); // 64 * 5 = 320 + EXPECT(migraphx::ceil_mul_of(250, 64) == 256); // 64 * 4 = 256 + EXPECT(migraphx::ceil_mul_of(256, 64) == 256); // Already aligned + EXPECT(migraphx::ceil_mul_of(257, 64) == 320); // 64 * 5 = 320 // Warp size alignment (32 threads) EXPECT(migraphx::ceil_mul_of(30, 32) == 32); @@ -110,16 +110,16 @@ TEST_CASE(ceil_mul_of_flash_attention_use_case) // Simulating the padding of sequence length to be divisible by number of groups // Example 1: sequence_length=100, num_groups=8 - EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 8 * 13 = 104 + EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 8 * 13 = 104 // Example 2: sequence_length=127, num_groups=16 - EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 16 * 8 = 128 + EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 16 * 8 = 128 // Example 3: sequence_length=200, num_groups=32 - EXPECT(migraphx::ceil_mul_of(200, 32) == 224); // 32 * 7 = 224 + EXPECT(migraphx::ceil_mul_of(200, 32) == 224); // 32 * 7 = 224 // Example 4: Already divisible - EXPECT(migraphx::ceil_mul_of(192, 32) == 192); // Already divisible by 32 + EXPECT(migraphx::ceil_mul_of(192, 32) == 192); // Already divisible by 32 } TEST_CASE(ceil_mul_of_consistency_with_integer_divide_ceil) @@ -128,14 +128,7 @@ TEST_CASE(ceil_mul_of_consistency_with_integer_divide_ceil) // This tests the implementation relationship std::size_t test_cases[][2] = { - {10, 3}, - {15, 4}, - {100, 7}, - {256, 32}, - {1000, 13}, - {1, 10}, - {0, 5} - }; + {10, 3}, {15, 4}, {100, 7}, {256, 32}, {1000, 13}, {1, 10}, {0, 5}}; for(const auto& tc : test_cases) { @@ -143,7 +136,7 @@ TEST_CASE(ceil_mul_of_consistency_with_integer_divide_ceil) std::size_t y = tc[1]; std::size_t expected = y * migraphx::integer_divide_ceil(x, y); - std::size_t actual = migraphx::ceil_mul_of(x, y); + std::size_t actual = migraphx::ceil_mul_of(x, y); EXPECT(actual == expected); } @@ -152,15 +145,15 @@ TEST_CASE(ceil_mul_of_consistency_with_integer_divide_ceil) TEST_CASE(ceil_mul_of_large_numbers) { // Test with larger sequence lengths that might appear in real models - EXPECT(migraphx::ceil_mul_of(1024, 16) == 1024); // Already aligned - EXPECT(migraphx::ceil_mul_of(1025, 16) == 1040); // 16 * 65 = 1040 - EXPECT(migraphx::ceil_mul_of(2048, 64) == 2048); // Already aligned - EXPECT(migraphx::ceil_mul_of(2049, 64) == 2112); // 64 * 33 = 2112 + EXPECT(migraphx::ceil_mul_of(1024, 16) == 1024); // Already aligned + EXPECT(migraphx::ceil_mul_of(1025, 16) == 1040); // 16 * 65 = 1040 + EXPECT(migraphx::ceil_mul_of(2048, 64) == 2048); // Already aligned + EXPECT(migraphx::ceil_mul_of(2049, 64) == 2112); // 64 * 33 = 2112 // Very large numbers - EXPECT(migraphx::ceil_mul_of(10000, 128) == 10112); // 128 * 79 = 10112 - EXPECT(migraphx::ceil_mul_of(16384, 256) == 16384); // Already aligned - EXPECT(migraphx::ceil_mul_of(16385, 256) == 16640); // 256 * 65 = 16640 + EXPECT(migraphx::ceil_mul_of(10000, 128) == 10112); // 128 * 79 = 10112 + EXPECT(migraphx::ceil_mul_of(16384, 256) == 16384); // Already aligned + EXPECT(migraphx::ceil_mul_of(16385, 256) == 16640); // 256 * 65 = 16640 } TEST_CASE(split_dim_basic) @@ -169,79 +162,87 @@ TEST_CASE(split_dim_basic) // Should split 100 into chunks > 10 // 100 = 2^2 * 5^2; factors: 2,2,5,5 -> splits = 20, remaining = 5 - std::size_t dim100 = 100; + std::size_t dim100 = 100; std::size_t result100 = migraphx::split_dim(dim100, 10); - EXPECT(result100 == 20); // 100/20 = 5, stops because 5 <= 10 + EXPECT(result100 == 20); // 100/20 = 5, stops because 5 <= 10 - // Should split 64 into chunks > 10 + // Should split 64 into chunks > 10 // 64 = 2^6; can use 2,2,2 -> splits = 8, remaining = 8 - std::size_t dim64 = 64; + std::size_t dim64 = 64; std::size_t result64 = migraphx::split_dim(dim64, 10); - EXPECT(result64 == 8); // 64/8 = 8, stops because 8 <= 10 + EXPECT(result64 == 8); // 64/8 = 8, stops because 8 <= 10 // Should not split if already small enough - std::size_t dim10 = 10; + std::size_t dim10 = 10; std::size_t result10 = migraphx::split_dim(dim10, 10); - EXPECT(result10 == 1); // 10 is not > 10, so no split + EXPECT(result10 == 1); // 10 is not > 10, so no split - std::size_t dim11 = 11; + std::size_t dim11 = 11; std::size_t result11 = migraphx::split_dim(dim11, 10); - EXPECT(result11 == 11); // 11 is a factor itself, 11/11 = 1 + EXPECT(result11 == 11); // 11 is a factor itself, 11/11 = 1 // Prime numbers that can't be factored - std::size_t dim13 = 13; + std::size_t dim13 = 13; std::size_t result13 = migraphx::split_dim(dim13, 10); - EXPECT(result13 == 1); // 13 is prime (not in factor list) + EXPECT(result13 == 1); // 13 is prime (not in factor list) - std::size_t dim17 = 17; + std::size_t dim17 = 17; std::size_t result17 = migraphx::split_dim(dim17, 10); - EXPECT(result17 == 1); // 17 is prime (not in factor list) + EXPECT(result17 == 1); // 17 is prime (not in factor list) // Numbers with factors in [2,3,5,7,11] - std::size_t dim30 = 30; + std::size_t dim30 = 30; std::size_t result30 = migraphx::split_dim(dim30, 5); - EXPECT(result30 == 6); // 30 = 2*3*5, splits to 5 + EXPECT(result30 == 6); // 30 = 2*3*5, splits to 5 - std::size_t dim77 = 77; + std::size_t dim77 = 77; std::size_t result77 = migraphx::split_dim(dim77, 10); - EXPECT(result77 == 77); // can be evenly split into 11 size chunks; next divisor splits to 1 size chunks + EXPECT(result77 == + 77); // can be evenly split into 11 size chunks; next divisor splits to 1 size chunks } TEST_CASE(split_dim_with_max_splits) { // Test with explicit max_splits constraint - // Note: max_splits is NOT a hard cap - function returns smallest split factor > max_splits that evenly divides dimension + // Note: max_splits is NOT a hard cap - function returns smallest split factor > max_splits that + // evenly divides dimension // When split factor would exceed max_splits, returns next valid divisor - std::size_t dim100a = 100; + std::size_t dim100a = 100; std::size_t result100a = migraphx::split_dim(dim100a, 10, 4); - EXPECT(result100a == 4); // 100 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 + EXPECT(result100a == + 4); // 100 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 - std::size_t dim100b = 100; + std::size_t dim100b = 100; std::size_t result100b = migraphx::split_dim(dim100b, 10, 2); EXPECT(result100b == 2); // Max splits doesn't force splitting if min_size constraint would be violated - std::size_t dim20 = 20; + std::size_t dim20 = 20; std::size_t result20 = migraphx::split_dim(dim20, 10, 4); - EXPECT(result20 == 2); // Can only split to 2 (20/2=10, not > 10) + EXPECT(result20 == 2); // Can only split to 2 (20/2=10, not > 10) - std::size_t dim15 = 15; + std::size_t dim15 = 15; std::size_t result15 = migraphx::split_dim(dim15, 10, 4); - EXPECT(result15 == 3); // 15 = 3*5, splits to 3, remaining = 5 + EXPECT(result15 == 3); // 15 = 3*5, splits to 3, remaining = 5 // Test with powers of 2 - std::size_t dim128a = 128; + std::size_t dim128a = 128; std::size_t result128a = migraphx::split_dim(dim128a, 10, 8); - EXPECT(result128a == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 + EXPECT( + result128a == + 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 - std::size_t dim128b = 128; + std::size_t dim128b = 128; std::size_t result128b = migraphx::split_dim(dim128b, 10, 4); - EXPECT(result128b == 4); // 128 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 + EXPECT(result128b == + 4); // 128 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4 - std::size_t dim128c = 128; + std::size_t dim128c = 128; std::size_t result128c = migraphx::split_dim(dim128c, 20, 8); - EXPECT(result128c == 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 + EXPECT( + result128c == + 8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8 } TEST_CASE(split_dim_edge_cases) @@ -249,39 +250,39 @@ TEST_CASE(split_dim_edge_cases) // Test edge cases // Very small dimensions - std::size_t dim1 = 1; + std::size_t dim1 = 1; std::size_t result1 = migraphx::split_dim(dim1, 0); - EXPECT(result1 == 1); // 1 can't be split + EXPECT(result1 == 1); // 1 can't be split - std::size_t dim2a = 2; + std::size_t dim2a = 2; std::size_t result2a = migraphx::split_dim(dim2a, 0); - EXPECT(result2a == 2); // 2/2 = 1 > 0 + EXPECT(result2a == 2); // 2/2 = 1 > 0 - std::size_t dim2b = 2; + std::size_t dim2b = 2; std::size_t result2b = migraphx::split_dim(dim2b, 1); - EXPECT(result2b == 2); // 2/2 = 1, but we continue while r > min_size, so 2 > 1 allows split + EXPECT(result2b == 2); // 2/2 = 1, but we continue while r > min_size, so 2 > 1 allows split // Exact boundary conditions - std::size_t dim20a = 20; + std::size_t dim20a = 20; std::size_t result20a = migraphx::split_dim(dim20a, 9); - EXPECT(result20a == 4); // 20 = 2^2 * 5, factors 2,2 before 20/4=5 <= 9 + EXPECT(result20a == 4); // 20 = 2^2 * 5, factors 2,2 before 20/4=5 <= 9 - std::size_t dim20b = 20; + std::size_t dim20b = 20; std::size_t result20b = migraphx::split_dim(dim20b, 10); - EXPECT(result20b == 2); // 20 = 2^2 * 5, factors 2 before 20/2=10 <= 10 + EXPECT(result20b == 2); // 20 = 2^2 * 5, factors 2 before 20/2=10 <= 10 - std::size_t dim21 = 21; + std::size_t dim21 = 21; std::size_t result21 = migraphx::split_dim(dim21, 10); - EXPECT(result21 == 3); // 21 = 3*7, splits by 3 first, 21/3 = 7 <= 10 + EXPECT(result21 == 3); // 21 = 3*7, splits by 3 first, 21/3 = 7 <= 10 // Large prime numbers - std::size_t dim97 = 97; + std::size_t dim97 = 97; std::size_t result97 = migraphx::split_dim(dim97, 10); - EXPECT(result97 == 1); // 97 is prime + EXPECT(result97 == 1); // 97 is prime - std::size_t dim101 = 101; + std::size_t dim101 = 101; std::size_t result101 = migraphx::split_dim(dim101, 10); - EXPECT(result101 == 1); // 101 is prime + EXPECT(result101 == 1); // 101 is prime } TEST_CASE(split_dim_factorization_order) @@ -291,22 +292,22 @@ TEST_CASE(split_dim_factorization_order) // 60 = 2^2 * 3 * 5 // With min_size=10: 60->30->15->5 (stops because 5 <= 10) // Factors used: 2, 2, 3 (product = 12) - std::size_t dim60 = 60; + std::size_t dim60 = 60; std::size_t result60 = migraphx::split_dim(dim60, 10); - EXPECT(result60 == 12); // 60/12 = 5 + EXPECT(result60 == 12); // 60/12 = 5 // 210 = 2 * 3 * 5 * 7 // With min_size=20: continues factoring while 210 > 20 // Factors all: 2*3*5*7 = 210, but stops at 2*3*5 = 30 since 210/30 = 7 <= 20 - std::size_t dim210 = 210; + std::size_t dim210 = 210; std::size_t result210 = migraphx::split_dim(dim210, 20); - EXPECT(result210 == 30); // 210/30 = 7 + EXPECT(result210 == 30); // 210/30 = 7 // 462 = 2 * 3 * 7 * 11 // With min_size=30: 462->231->77->11 (stops because 11 <= 30) - std::size_t dim462 = 462; + std::size_t dim462 = 462; std::size_t result462 = migraphx::split_dim(dim462, 30); - EXPECT(result462 == 42); // 462/42 = 11 <= 30 + EXPECT(result462 == 42); // 462/42 = 11 <= 30 } TEST_CASE(split_dim_reduce_use_case) @@ -315,22 +316,23 @@ TEST_CASE(split_dim_reduce_use_case) // These are realistic values that might appear in reduction operations // Large reduction dimension with typical min_size and max_splits - std::size_t dim1024 = 1024; + std::size_t dim1024 = 1024; std::size_t result1024 = migraphx::split_dim(dim1024, 64, 16); - EXPECT(result1024 == 16); // 1024 = 2^10; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops + EXPECT(result1024 == + 16); // 1024 = 2^10; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops - std::size_t dim1000 = 1000; + std::size_t dim1000 = 1000; std::size_t result1000 = migraphx::split_dim(dim1000, 64, 16); - EXPECT(result1000 == 40); // 1000 = 2^3 * 5^3; n=8, r=125 > 64, so next would be n=40 > 16 + EXPECT(result1000 == 40); // 1000 = 2^3 * 5^3; n=8, r=125 > 64, so next would be n=40 > 16 // Smaller dimensions - std::size_t dim256 = 256; + std::size_t dim256 = 256; std::size_t result256 = migraphx::split_dim(dim256, 32, 8); - EXPECT(result256 == 8); // 256 = 2^8; would split to 16 (256/16=16), exceeds max_splits=8 + EXPECT(result256 == 8); // 256 = 2^8; would split to 16 (256/16=16), exceeds max_splits=8 - std::size_t dim200 = 200; + std::size_t dim200 = 200; std::size_t result200 = migraphx::split_dim(dim200, 32, 8); - EXPECT(result200 == 8); // 200 = 2^3 * 5^2; would split to 8 (200/8=25), exceeds max_splits=8 + EXPECT(result200 == 8); // 200 = 2^3 * 5^2; would split to 8 (200/8=25), exceeds max_splits=8 } TEST_CASE(split_dim_flash_attention_use_case) @@ -339,33 +341,34 @@ TEST_CASE(split_dim_flash_attention_use_case) // Sequence lengths need to be split for parallel processing // Typical sequence lengths in attention - std::size_t dim2048 = 2048; + std::size_t dim2048 = 2048; std::size_t result2048 = migraphx::split_dim(dim2048, 128, 16); - EXPECT(result2048 == 16); // 2048 = 2^11; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops + EXPECT(result2048 == + 16); // 2048 = 2^11; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops - std::size_t dim4096 = 4096; + std::size_t dim4096 = 4096; std::size_t result4096 = migraphx::split_dim(dim4096, 256, 16); - EXPECT(result4096 == 16); // 4096 = 2^12; splits to 16 (4096/16=256) + EXPECT(result4096 == 16); // 4096 = 2^12; splits to 16 (4096/16=256) // Non-aligned sequence lengths // 1536 = 2^9 * 3; would continue past max_splits=16 to get 1536/24=64 < 128 - std::size_t dim1536 = 1536; + std::size_t dim1536 = 1536; std::size_t result1536 = migraphx::split_dim(dim1536, 128, 16); - EXPECT(result1536 == 16); // 1536 = 2^9 * 3; n=16, r=128, stops (r not > 128) + EXPECT(result1536 == 16); // 1536 = 2^9 * 3; n=16, r=128, stops (r not > 128) // 3000 = 2^3 * 3 * 5^3 - std::size_t dim3000 = 3000; + std::size_t dim3000 = 3000; std::size_t result3000 = migraphx::split_dim(dim3000, 200, 16); - EXPECT(result3000 == 24); // 3*5 = 24, 3000/24 = 125 + EXPECT(result3000 == 24); // 3*5 = 24, 3000/24 = 125 // Smaller sequences - std::size_t dim512 = 512; + std::size_t dim512 = 512; std::size_t result512 = migraphx::split_dim(dim512, 64, 8); - EXPECT(result512 == 8); // 512 = 2^9; stops at n=8 (512/8=64) + EXPECT(result512 == 8); // 512 = 2^9; stops at n=8 (512/8=64) - std::size_t dim768 = 768; + std::size_t dim768 = 768; std::size_t result768 = migraphx::split_dim(dim768, 64, 8); - EXPECT(result768 == 8); // 768 = 2^8 * 3; n=8, r=128 > 64, stops (r not > 128) + EXPECT(result768 == 8); // 768 = 2^8 * 3; n=8, r=128 > 64, stops (r not > 128) } TEST_CASE(split_dim_no_limit_vs_explicit_max) @@ -374,35 +377,35 @@ TEST_CASE(split_dim_no_limit_vs_explicit_max) std::size_t large_max = std::numeric_limits::max(); // These should produce identical results - std::size_t dim1000a = 1000; + std::size_t dim1000a = 1000; std::size_t result1000a = migraphx::split_dim(dim1000a, 10); - std::size_t dim1000b = 1000; + std::size_t dim1000b = 1000; std::size_t result1000b = migraphx::split_dim(dim1000b, 10, large_max); EXPECT(result1000a == result1000b); - std::size_t dim512a = 512; + std::size_t dim512a = 512; std::size_t result512a = migraphx::split_dim(dim512a, 32); - std::size_t dim512b = 512; + std::size_t dim512b = 512; std::size_t result512b = migraphx::split_dim(dim512b, 32, large_max); EXPECT(result512a == result512b); - std::size_t dim360a = 360; + std::size_t dim360a = 360; std::size_t result360a = migraphx::split_dim(dim360a, 20); - std::size_t dim360b = 360; + std::size_t dim360b = 360; std::size_t result360b = migraphx::split_dim(dim360b, 20, large_max); EXPECT(result360a == result360b); // Test that max_splits affects the result (but doesn't necessarily limit it) // max_splits acts as a threshold - result may exceed it - std::size_t dim1000c = 1000; + std::size_t dim1000c = 1000; std::size_t result1000c = migraphx::split_dim(dim1000c, 10); - std::size_t dim1000d = 1000; + std::size_t dim1000d = 1000; std::size_t result1000d = migraphx::split_dim(dim1000d, 10, 8); EXPECT(result1000c != result1000d); - std::size_t dim512c = 512; + std::size_t dim512c = 512; std::size_t result512c = migraphx::split_dim(dim512c, 8); - std::size_t dim512d = 512; + std::size_t dim512d = 512; std::size_t result512d = migraphx::split_dim(dim512d, 8, 16); EXPECT(result512c != result512d); } @@ -418,7 +421,7 @@ TEST_CASE(split_dim_consistency_check) for(std::size_t min_size : {8, 16, 32, 64}) { std::size_t dim_copy = dim; - std::size_t splits = migraphx::split_dim(dim_copy, min_size); + std::size_t splits = migraphx::split_dim(dim_copy, min_size); if(splits > 1) { // The algorithm continues while r > min_size, @@ -438,7 +441,7 @@ TEST_CASE(split_dim_consistency_check) for(std::size_t max_splits : {4, 8, 16}) { std::size_t dim_copy = dim; - std::size_t splits = migraphx::split_dim(dim_copy, 10, max_splits); + std::size_t splits = migraphx::split_dim(dim_copy, 10, max_splits); // Result should evenly divide dimension EXPECT(dim % splits == 0); // Result should make remaining size < min_size (10) @@ -450,7 +453,7 @@ TEST_CASE(split_dim_consistency_check) for(std::size_t dim : {256, 512, 1024}) { std::size_t dim_copy1 = dim; - std::size_t splits_8 = migraphx::split_dim(dim_copy1, 8); + std::size_t splits_8 = migraphx::split_dim(dim_copy1, 8); std::size_t dim_copy2 = dim; std::size_t splits_16 = migraphx::split_dim(dim_copy2, 16); std::size_t dim_copy3 = dim; From dd7eb98d85bc6689bccba1318fb2fba192a5aedc Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 23:13:02 -0600 Subject: [PATCH 19/30] cursor made test cases to handle drop in codecov --- test/fuse_attention.cpp | 104 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 41b145fc267..869fe99289b 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -31,20 +31,124 @@ #include #include #include +#include +#include +#include #include #include #include #include #include #include +#include MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FLASH_DECODING); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); static void run_pass(migraphx::program& p, migraphx::fuse_attention fa = {}) { migraphx::run_passes(p, {fa, migraphx::dead_code_elimination{}}); } +// Test helper functions used in fuse_attention pass +TEST_CASE(get_num_splits_from_member) +{ + // Test that member variable takes precedence over environment variable + migraphx::fuse_attention fa; + fa.flash_decoding_num_splits = 8; + + // This test would require exposing get_num_splits as a public function or friend function + // For now, we test it indirectly through the pass behavior + EXPECT(fa.flash_decoding_num_splits.has_value()); + EXPECT(*fa.flash_decoding_num_splits == 8); +} + +TEST_CASE(calculate_flash_decoding_splits_basic) +{ + // Test the calculate_flash_decoding_splits function indirectly through split_dim + // Since it's essentially a wrapper around split_dim, we test the behavior + + // Test case 1: sequence_length that can be split evenly + // 256 with min_chunk=32 should split to 8 (256/8 = 32) + std::size_t seq_len1 = 256; + std::size_t result1 = migraphx::split_dim(seq_len1, 32, 16); + EXPECT(result1 == 8); + + // Test case 2: sequence_length with max_splits constraint + // 1024 with min_chunk=64 and max_splits=8 should be limited to 8 + std::size_t seq_len2 = 1024; + std::size_t result2 = migraphx::split_dim(seq_len2, 64, 8); + EXPECT(result2 == 8); + + // Test case 3: small sequence that shouldn't be split + // 32 with min_chunk=32 should return 1 (no split) + std::size_t seq_len3 = 32; + std::size_t result3 = migraphx::split_dim(seq_len3, 32, 16); + EXPECT(result3 == 1); + + // Test case 4: prime number sequence length + // 97 with min_chunk=10 should return 1 (can't split prime) + std::size_t seq_len4 = 97; + std::size_t result4 = migraphx::split_dim(seq_len4, 10, 16); + EXPECT(result4 == 1); + + // Test case 5: typical attention sequence lengths + // 2048 with min_chunk=128 and max_splits=16 + std::size_t seq_len5 = 2048; + std::size_t result5 = migraphx::split_dim(seq_len5, 128, 16); + EXPECT(result5 == 16); +} + + +TEST_CASE(padding_calculation) +{ + // Test padding calculation for flash decoding + // When sequence_length is not evenly divisible by actual_groups + + // Test case 1: evenly divisible - no padding needed + std::size_t seq_len1 = 256; + std::size_t groups1 = 8; + std::size_t padding1 = 0; + if(seq_len1 % groups1 != 0) + { + padding1 = migraphx::ceil_mul_of(seq_len1, groups1) - seq_len1; + } + EXPECT(padding1 == 0); + + // Test case 2: not evenly divisible - padding needed + std::size_t seq_len2 = 100; + std::size_t groups2 = 8; + std::size_t padding2 = 0; + if(seq_len2 % groups2 != 0) + { + padding2 = migraphx::ceil_mul_of(seq_len2, groups2) - seq_len2; + } + EXPECT(padding2 == 4); // 104 - 100 = 4 + + // Test case 3: sequence length = 127, groups = 16 + std::size_t seq_len3 = 127; + std::size_t groups3 = 16; + std::size_t padding3 = 0; + if(seq_len3 % groups3 != 0) + { + padding3 = migraphx::ceil_mul_of(seq_len3, groups3) - seq_len3; + } + EXPECT(padding3 == 1); // 128 - 127 = 1 + + // Test case 4: large sequence with padding + std::size_t seq_len4 = 2049; + std::size_t groups4 = 32; + std::size_t padding4 = 0; + if(seq_len4 % groups4 != 0) + { + padding4 = migraphx::ceil_mul_of(seq_len4, groups4) - seq_len4; + } + EXPECT(padding4 == 31); // 2080 - 2049 = 31 +} + TEST_CASE(gemm_softmax_gemm) { migraphx::shape s1{migraphx::shape::half_type, {1, 12, 256, 256}}; From 5b726f35c80c0375e43ba5ef6d93b07454d6b740 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 23:20:28 -0600 Subject: [PATCH 20/30] move group calculation into separate function --- src/fuse_attention.cpp | 56 ++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 21 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index ea0f7aa76db..7031010e4e6 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -70,6 +70,39 @@ inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, return split_dim(r, min_chunk_size, max_splits); } +// calculate the actual number of groups for flash decoding +// returns 0 if no splitting should be performed +inline std::size_t calculate_groups(std::size_t groups, std::size_t sequence_length) +{ + // if groups is explicitly set and valid, use it + if(groups > 1) + return groups; + + // if groups is 0, auto-calculate based on sequence length + if(groups == 0) + { + // TODO: run experiments to find the optimal values for min_chunk and max_splits + std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); + std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 16); + std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 32); + + // skip if sequence is too short + if(sequence_length < threshold) + return 0; + + std::size_t actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); + + // return 0 if auto-calculation determines no splitting needed + if(actual_groups <= 1) + return 0; + + return actual_groups; + } + + // groups == 1 or invalid, no splitting + return 0; +} + // TODO: Write this in matcher.hpp as a general matcher for iterating through inputs inline auto pointwise_inputs() { @@ -483,27 +516,8 @@ struct find_flash_decoding auto k_shape = k_param->get_shape(); std::size_t sequence_length = k_shape.lens().back(); - // Determine actual number of splits to use - std::size_t actual_groups = groups; - if(groups == 0) - { - // TODO: run experiments to find the optimal values for min_chunk and max_splits - std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); - std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 16); - std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 32); - - if(sequence_length < threshold) - return; - - actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); - - // skip if auto-calculation determines no splitting needed - if(actual_groups <= 1) - return; - } - - // skip if no actual splitting (actual_groups must be > 1) - if(actual_groups <= 1) + std::size_t actual_groups = calculate_groups(groups, sequence_length); + if(actual_groups == 0) return; // calculate padding if sequence length not evenly divisible From 86ebe6fa196d6c172892da9b37f2b597f61b12da Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Thu, 11 Dec 2025 23:44:18 -0600 Subject: [PATCH 21/30] rework splits/groups math and member vars. need a member var to be able to manually set splits if desired --- src/fuse_attention.cpp | 39 ++++++++++++++------------------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 7031010e4e6..e97ad6a49e0 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -72,7 +72,7 @@ inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, // calculate the actual number of groups for flash decoding // returns 0 if no splitting should be performed -inline std::size_t calculate_groups(std::size_t groups, std::size_t sequence_length) +inline std::size_t calculate_groups(std::size_t groups, std::size_t sequence_length, std::size_t threshold) { // if groups is explicitly set and valid, use it if(groups > 1) @@ -81,15 +81,14 @@ inline std::size_t calculate_groups(std::size_t groups, std::size_t sequence_len // if groups is 0, auto-calculate based on sequence length if(groups == 0) { - // TODO: run experiments to find the optimal values for min_chunk and max_splits - std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); - std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 16); - std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 32); - // skip if sequence is too short if(sequence_length < threshold) return 0; + // TODO: run experiments to find the optimal values for min_chunk and max_splits + std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); + std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 16); + std::size_t actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); // return 0 if auto-calculation determines no splitting needed @@ -278,8 +277,8 @@ struct find_attention struct find_flash_decoding { - // number of groups (0 means auto-calculate) - std::size_t groups; + // optional number of splits from fuse_attention pass config + std::optional configured_splits; auto matcher() const { @@ -516,7 +515,11 @@ struct find_flash_decoding auto k_shape = k_param->get_shape(); std::size_t sequence_length = k_shape.lens().back(); - std::size_t actual_groups = calculate_groups(groups, sequence_length); + // read groups configuration from pass config or environment variable + std::size_t groups = get_num_splits(configured_splits); + std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 32); + + std::size_t actual_groups = calculate_groups(groups, sequence_length, threshold); if(actual_groups == 0) return; @@ -888,25 +891,11 @@ void fuse_attention::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); } - // Apply flash decoding if enabled - bool flash_enabled = false; - std::size_t num_splits = 0; - - std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); - // enable flash decoding if splits configured or explicitly enabled + std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); if(configured_splits > 0 or is_flash_decoding_enabled()) { - flash_enabled = true; - - // 0 means auto-calculate - num_splits = configured_splits > 0 ? configured_splits : 0; - } - - if(flash_enabled) - { - match::find_matches( - mpm, find_flash_decoding{.groups = num_splits}); + match::find_matches(mpm, find_flash_decoding{.configured_splits = flash_decoding_num_splits}); mpm.run_pass(dead_code_elimination{}); } } From 3705dedc0a0f2ca7e7c4b06df94d4a12f3465472 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Fri, 12 Dec 2025 11:47:29 -0600 Subject: [PATCH 22/30] fix busted cusor tests, clang-format & tidy --- src/fuse_attention.cpp | 13 +++++--- src/include/migraphx/split_factor.hpp | 3 +- test/fuse_attention.cpp | 45 ++++++++++----------------- test/math_utils_test.cpp | 6 ++-- 4 files changed, 29 insertions(+), 38 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index e97ad6a49e0..6ca6028793d 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -72,7 +72,8 @@ inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, // calculate the actual number of groups for flash decoding // returns 0 if no splitting should be performed -inline std::size_t calculate_groups(std::size_t groups, std::size_t sequence_length, std::size_t threshold) +inline std::size_t +calculate_groups(std::size_t groups, std::size_t sequence_length, std::size_t threshold) { // if groups is explicitly set and valid, use it if(groups > 1) @@ -89,7 +90,8 @@ inline std::size_t calculate_groups(std::size_t groups, std::size_t sequence_len std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 16); - std::size_t actual_groups = calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); + std::size_t actual_groups = + calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); // return 0 if auto-calculation determines no splitting needed if(actual_groups <= 1) @@ -516,7 +518,7 @@ struct find_flash_decoding std::size_t sequence_length = k_shape.lens().back(); // read groups configuration from pass config or environment variable - std::size_t groups = get_num_splits(configured_splits); + std::size_t groups = get_num_splits(configured_splits); std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 32); std::size_t actual_groups = calculate_groups(groups, sequence_length, threshold); @@ -528,7 +530,7 @@ struct find_flash_decoding if(sequence_length % actual_groups != 0) { // round up to nearest multiple of actual_groups - padding_needed = ceil_mul_of(sequence_length, actual_groups) - sequence_length; + padding_needed = ceil_mul_of(sequence_length, actual_groups) - sequence_length; } // create mapping from submodule params to main module inputs @@ -895,7 +897,8 @@ void fuse_attention::apply(module_pass_manager& mpm) const std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); if(configured_splits > 0 or is_flash_decoding_enabled()) { - match::find_matches(mpm, find_flash_decoding{.configured_splits = flash_decoding_num_splits}); + match::find_matches(mpm, + find_flash_decoding{.configured_splits = flash_decoding_num_splits}); mpm.run_pass(dead_code_elimination{}); } } diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index 2571425b482..e8c5236d91d 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -43,7 +43,8 @@ inline namespace MIGRAPHX_INLINE_NS { * To compute the number of split groups it finds the largest * divisor that can divide dimension to make it less than min_size. * - * @param r The value to split. This is passed by reference and will be modified to the remaining value after splitting. + * @param r The value to split. This is passed by reference and will be modified to the remaining + * value after splitting. * @param min_size Target threshold - splits until remaining size is less than this value * @param max_splits Target threshold - if reached, returns the smallest split factor greater than * or equal to max_splits that evenly divides dimension. Optional diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 869fe99289b..e7fc5e9deba 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -74,35 +74,34 @@ TEST_CASE(calculate_flash_decoding_splits_basic) // Test case 1: sequence_length that can be split evenly // 256 with min_chunk=32 should split to 8 (256/8 = 32) std::size_t seq_len1 = 256; - std::size_t result1 = migraphx::split_dim(seq_len1, 32, 16); + std::size_t result1 = migraphx::split_dim(seq_len1, 32, 16); EXPECT(result1 == 8); // Test case 2: sequence_length with max_splits constraint // 1024 with min_chunk=64 and max_splits=8 should be limited to 8 std::size_t seq_len2 = 1024; - std::size_t result2 = migraphx::split_dim(seq_len2, 64, 8); + std::size_t result2 = migraphx::split_dim(seq_len2, 64, 8); EXPECT(result2 == 8); // Test case 3: small sequence that shouldn't be split // 32 with min_chunk=32 should return 1 (no split) std::size_t seq_len3 = 32; - std::size_t result3 = migraphx::split_dim(seq_len3, 32, 16); + std::size_t result3 = migraphx::split_dim(seq_len3, 32, 16); EXPECT(result3 == 1); // Test case 4: prime number sequence length // 97 with min_chunk=10 should return 1 (can't split prime) std::size_t seq_len4 = 97; - std::size_t result4 = migraphx::split_dim(seq_len4, 10, 16); + std::size_t result4 = migraphx::split_dim(seq_len4, 10, 16); EXPECT(result4 == 1); // Test case 5: typical attention sequence lengths // 2048 with min_chunk=128 and max_splits=16 std::size_t seq_len5 = 2048; - std::size_t result5 = migraphx::split_dim(seq_len5, 128, 16); + std::size_t result5 = migraphx::split_dim(seq_len5, 128, 16); EXPECT(result5 == 16); } - TEST_CASE(padding_calculation) { // Test padding calculation for flash decoding @@ -110,42 +109,30 @@ TEST_CASE(padding_calculation) // Test case 1: evenly divisible - no padding needed std::size_t seq_len1 = 256; - std::size_t groups1 = 8; + std::size_t groups1 = 8; + // 256 % 8 == 0, so no padding needed std::size_t padding1 = 0; - if(seq_len1 % groups1 != 0) - { - padding1 = migraphx::ceil_mul_of(seq_len1, groups1) - seq_len1; - } EXPECT(padding1 == 0); // Test case 2: not evenly divisible - padding needed std::size_t seq_len2 = 100; - std::size_t groups2 = 8; - std::size_t padding2 = 0; - if(seq_len2 % groups2 != 0) - { - padding2 = migraphx::ceil_mul_of(seq_len2, groups2) - seq_len2; - } + std::size_t groups2 = 8; + // 100 % 8 != 0, so padding is needed + std::size_t padding2 = migraphx::ceil_mul_of(seq_len2, groups2) - seq_len2; EXPECT(padding2 == 4); // 104 - 100 = 4 // Test case 3: sequence length = 127, groups = 16 std::size_t seq_len3 = 127; - std::size_t groups3 = 16; - std::size_t padding3 = 0; - if(seq_len3 % groups3 != 0) - { - padding3 = migraphx::ceil_mul_of(seq_len3, groups3) - seq_len3; - } + std::size_t groups3 = 16; + // 127 % 16 != 0, so padding is needed + std::size_t padding3 = migraphx::ceil_mul_of(seq_len3, groups3) - seq_len3; EXPECT(padding3 == 1); // 128 - 127 = 1 // Test case 4: large sequence with padding std::size_t seq_len4 = 2049; - std::size_t groups4 = 32; - std::size_t padding4 = 0; - if(seq_len4 % groups4 != 0) - { - padding4 = migraphx::ceil_mul_of(seq_len4, groups4) - seq_len4; - } + std::size_t groups4 = 32; + // 2049 % 32 != 0, so padding is needed + std::size_t padding4 = migraphx::ceil_mul_of(seq_len4, groups4) - seq_len4; EXPECT(padding4 == 31); // 2080 - 2049 = 31 } diff --git a/test/math_utils_test.cpp b/test/math_utils_test.cpp index 3750213e170..2e8d40310dc 100644 --- a/test/math_utils_test.cpp +++ b/test/math_utils_test.cpp @@ -127,10 +127,10 @@ TEST_CASE(ceil_mul_of_consistency_with_integer_divide_ceil) // Verify that ceil_mul_of(x, y) == y * integer_divide_ceil(x, y) // This tests the implementation relationship - std::size_t test_cases[][2] = { + const std::size_t test_cases[][2] = { {10, 3}, {15, 4}, {100, 7}, {256, 32}, {1000, 13}, {1, 10}, {0, 5}}; - for(const auto& tc : test_cases) + for(const auto* const tc : test_cases) { std::size_t x = tc[0]; std::size_t y = tc[1]; @@ -445,7 +445,7 @@ TEST_CASE(split_dim_consistency_check) // Result should evenly divide dimension EXPECT(dim % splits == 0); // Result should make remaining size < min_size (10) - EXPECT(dim / splits < 10 || splits >= max_splits); + EXPECT(dim / splits < 10 or splits >= max_splits); } } From 8c888ccb55fcb5cc71937ce14fcde4ef687bf246 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 17 Dec 2025 02:17:34 -0600 Subject: [PATCH 23/30] remove stupid tests that cursor keeps adding back --- test/fuse_attention.cpp | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index e7fc5e9deba..b5c35d129f7 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -68,34 +68,31 @@ TEST_CASE(get_num_splits_from_member) TEST_CASE(calculate_flash_decoding_splits_basic) { - // Test the calculate_flash_decoding_splits function indirectly through split_dim - // Since it's essentially a wrapper around split_dim, we test the behavior - - // Test case 1: sequence_length that can be split evenly + // sequence_length that can be split evenly // 256 with min_chunk=32 should split to 8 (256/8 = 32) std::size_t seq_len1 = 256; std::size_t result1 = migraphx::split_dim(seq_len1, 32, 16); EXPECT(result1 == 8); - // Test case 2: sequence_length with max_splits constraint + // sequence_length with max_splits constraint // 1024 with min_chunk=64 and max_splits=8 should be limited to 8 std::size_t seq_len2 = 1024; std::size_t result2 = migraphx::split_dim(seq_len2, 64, 8); EXPECT(result2 == 8); - // Test case 3: small sequence that shouldn't be split + // small sequence that shouldn't be split // 32 with min_chunk=32 should return 1 (no split) std::size_t seq_len3 = 32; std::size_t result3 = migraphx::split_dim(seq_len3, 32, 16); EXPECT(result3 == 1); - // Test case 4: prime number sequence length + // prime number sequence length // 97 with min_chunk=10 should return 1 (can't split prime) std::size_t seq_len4 = 97; std::size_t result4 = migraphx::split_dim(seq_len4, 10, 16); EXPECT(result4 == 1); - // Test case 5: typical attention sequence lengths + // typical attention sequence lengths // 2048 with min_chunk=128 and max_splits=16 std::size_t seq_len5 = 2048; std::size_t result5 = migraphx::split_dim(seq_len5, 128, 16); @@ -104,31 +101,21 @@ TEST_CASE(calculate_flash_decoding_splits_basic) TEST_CASE(padding_calculation) { - // Test padding calculation for flash decoding - // When sequence_length is not evenly divisible by actual_groups - - // Test case 1: evenly divisible - no padding needed - std::size_t seq_len1 = 256; - std::size_t groups1 = 8; - // 256 % 8 == 0, so no padding needed - std::size_t padding1 = 0; - EXPECT(padding1 == 0); - - // Test case 2: not evenly divisible - padding needed + // not evenly divisible - padding needed std::size_t seq_len2 = 100; std::size_t groups2 = 8; // 100 % 8 != 0, so padding is needed std::size_t padding2 = migraphx::ceil_mul_of(seq_len2, groups2) - seq_len2; EXPECT(padding2 == 4); // 104 - 100 = 4 - // Test case 3: sequence length = 127, groups = 16 + // sequence length = 127, groups = 16 std::size_t seq_len3 = 127; std::size_t groups3 = 16; // 127 % 16 != 0, so padding is needed std::size_t padding3 = migraphx::ceil_mul_of(seq_len3, groups3) - seq_len3; EXPECT(padding3 == 1); // 128 - 127 = 1 - // Test case 4: large sequence with padding + // large sequence with padding std::size_t seq_len4 = 2049; std::size_t groups4 = 32; // 2049 % 32 != 0, so padding is needed From 211156322813b34f3e0412771c1676e930181fad Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 29 Dec 2025 15:19:40 -0600 Subject: [PATCH 24/30] code cov tests; refactor a bit --- src/fuse_attention.cpp | 57 ++- src/include/migraphx/fuse_attention.hpp | 6 +- src/targets/gpu/fuse_mlir.cpp | 22 + .../gpu/include/migraphx/gpu/fuse_mlir.hpp | 1 + src/targets/gpu/target.cpp | 3 +- test/fuse_attention.cpp | 475 +++++++++++++++--- 6 files changed, 466 insertions(+), 98 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 6ca6028793d..8e850079b82 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -40,21 +40,32 @@ inline namespace MIGRAPHX_INLINE_NS { namespace { // env vars for flash decoding configuration -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_ENABLED); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); -bool is_flash_decoding_enabled() { return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); } +// Helper function to get config value with priority: struct member (if not default) > env var > default +template +std::size_t get_config_value(std::size_t struct_value, std::size_t default_value, EnvVar env_var) +{ + // if struct member is not the default value, use it + if(struct_value != default_value) + { + return struct_value; + } + + // otherwise return env var value, or default if not set + return value_of(env_var, default_value); +} // Get num_splits with priority: struct member > env var > 0 (not set) -std::size_t get_num_splits(const std::optional& member_num_splits) +std::size_t get_num_splits(std::size_t member_num_splits) { - // struct member var is used for testing - if(member_num_splits.has_value()) + // if struct member is set (non-zero), use it + if(member_num_splits > 0) { - return *member_num_splits; + return member_num_splits; } // otherwise return env var value, or 0 if not set @@ -73,7 +84,7 @@ inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, // calculate the actual number of groups for flash decoding // returns 0 if no splitting should be performed inline std::size_t -calculate_groups(std::size_t groups, std::size_t sequence_length, std::size_t threshold) +calculate_groups(std::size_t groups, std::size_t sequence_length, std::size_t threshold, std::size_t min_chunk_size, std::size_t max_splits) { // if groups is explicitly set and valid, use it if(groups > 1) @@ -86,12 +97,8 @@ calculate_groups(std::size_t groups, std::size_t sequence_length, std::size_t th if(sequence_length < threshold) return 0; - // TODO: run experiments to find the optimal values for min_chunk and max_splits - std::size_t min_chunk = value_of(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}, 32); - std::size_t max_splits = value_of(MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}, 16); - std::size_t actual_groups = - calculate_flash_decoding_splits(sequence_length, min_chunk, max_splits); + calculate_flash_decoding_splits(sequence_length, min_chunk_size, max_splits); // return 0 if auto-calculation determines no splitting needed if(actual_groups <= 1) @@ -279,8 +286,11 @@ struct find_attention struct find_flash_decoding { - // optional number of splits from fuse_attention pass config - std::optional configured_splits; + // configuration from fuse_attention pass config + std::size_t configured_splits; + std::size_t configured_threshold; + std::size_t configured_max_splits; + std::size_t configured_min_chunk_size; auto matcher() const { @@ -517,11 +527,13 @@ struct find_flash_decoding auto k_shape = k_param->get_shape(); std::size_t sequence_length = k_shape.lens().back(); - // read groups configuration from pass config or environment variable - std::size_t groups = get_num_splits(configured_splits); - std::size_t threshold = value_of(MIGRAPHX_FLASH_DECODING_THRESHOLD{}, 32); + // read configuration with priority: struct member (if not default) > env var > default + std::size_t groups = get_num_splits(configured_splits); + std::size_t threshold = get_config_value(configured_threshold, 32, MIGRAPHX_FLASH_DECODING_THRESHOLD{}); + std::size_t min_chunk_size = get_config_value(configured_min_chunk_size, 32, MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}); + std::size_t max_splits = get_config_value(configured_max_splits, 16, MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}); - std::size_t actual_groups = calculate_groups(groups, sequence_length, threshold); + std::size_t actual_groups = calculate_groups(groups, sequence_length, threshold, min_chunk_size, max_splits); if(actual_groups == 0) return; @@ -893,12 +905,15 @@ void fuse_attention::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); } - // enable flash decoding if splits configured or explicitly enabled + // enable flash decoding if splits configured or flash decoding is enabled std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); - if(configured_splits > 0 or is_flash_decoding_enabled()) + if(configured_splits > 0 or flash_decoding_enabled) { match::find_matches(mpm, - find_flash_decoding{.configured_splits = flash_decoding_num_splits}); + find_flash_decoding{.configured_splits = flash_decoding_num_splits, + .configured_threshold = flash_decoding_threshold, + .configured_max_splits = flash_decoding_max_splits, + .configured_min_chunk_size = flash_decoding_min_chunk_size}); mpm.run_pass(dead_code_elimination{}); } } diff --git a/src/include/migraphx/fuse_attention.hpp b/src/include/migraphx/fuse_attention.hpp index db07c9aac02..923fec72bc4 100644 --- a/src/include/migraphx/fuse_attention.hpp +++ b/src/include/migraphx/fuse_attention.hpp @@ -38,7 +38,11 @@ struct module_pass_manager; struct MIGRAPHX_EXPORT fuse_attention { bool attn_enabled = false; - std::optional flash_decoding_num_splits = std::nullopt; + bool flash_decoding_enabled = false; + std::size_t flash_decoding_num_splits = 0; + std::size_t flash_decoding_threshold = 32; + std::size_t flash_decoding_max_splits = 16; + std::size_t flash_decoding_min_chunk_size = 32; std::string name() const { return "fuse_attention"; } void apply(module_pass_manager& mpm) const; diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index cf3cf6f3416..4cbb8766744 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -48,6 +48,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_ENABLED); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** @@ -143,6 +144,27 @@ bool mlir_attention_enabled(context* ctx) #endif } +bool mlir_flash_decoding_enabled(context* ctx) +{ +#ifdef MIGRAPHX_MLIR + if(not mlir_enabled()) + return false; + + // Check if explicitly enabled via environment variable + if(enabled(MIGRAPHX_FLASH_DECODING_ENABLED{})) + return true; + + // Enable flash decoding by default for mi300 (when MLIR is available) + if(ctx != nullptr and starts_with(ctx->get_current_device().get_gfx_name(), "gfx94")) + return true; + + return false; +#else + // Without MLIR, only enable if explicitly requested via env var + return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); +#endif +} + #ifdef MIGRAPHX_MLIR struct mlir_op diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp index 82edc35e581..51ec4275f79 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp @@ -35,6 +35,7 @@ namespace gpu { MIGRAPHX_GPU_EXPORT bool mlir_enabled(); MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled(context* ctx); +MIGRAPHX_GPU_EXPORT bool mlir_flash_decoding_enabled(context* ctx); MIGRAPHX_GPU_EXPORT bool mlir_geg_multi_user_intermediates_supported(); struct MIGRAPHX_GPU_EXPORT fuse_mlir diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index de2da57ebc4..8eb08d83a4c 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -225,7 +225,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, simplify_reshapes{.enable_op_shape_transform_op=true}, dead_code_elimination{}, - enable_pass(mlir_enabled(), fuse_attention{mlir_attention_enabled(&ctx)}), + enable_pass(mlir_enabled(), fuse_attention{.attn_enabled = mlir_attention_enabled(&ctx), + .flash_decoding_enabled = mlir_flash_decoding_enabled(&ctx)}), dead_code_elimination{}, optimize_module{}, enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}), diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index b5c35d129f7..134f973afda 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -42,11 +42,6 @@ #include #include -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS); -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_FLASH_DECODING); -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); static void run_pass(migraphx::program& p, migraphx::fuse_attention fa = {}) { @@ -60,10 +55,11 @@ TEST_CASE(get_num_splits_from_member) migraphx::fuse_attention fa; fa.flash_decoding_num_splits = 8; - // This test would require exposing get_num_splits as a public function or friend function - // For now, we test it indirectly through the pass behavior - EXPECT(fa.flash_decoding_num_splits.has_value()); - EXPECT(*fa.flash_decoding_num_splits == 8); + // Test that struct members are set correctly + EXPECT(fa.flash_decoding_num_splits == 8); + EXPECT(fa.flash_decoding_threshold == 32); // default value + EXPECT(fa.flash_decoding_max_splits == 16); // default value + EXPECT(fa.flash_decoding_min_chunk_size == 32); // default value } TEST_CASE(calculate_flash_decoding_splits_basic) @@ -523,7 +519,7 @@ TEST_CASE(gemm_softmax_gemm_flash_decoding) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 2}); + run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -624,7 +620,7 @@ TEST_CASE(flash_decoding_3d) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 2}); + run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { @@ -713,70 +709,6 @@ TEST_CASE(flash_decoding_3d) EXPECT(p1.sort() == p2.sort()); } -TEST_CASE(flash_decoding_disabled) -{ - migraphx::shape s1{migraphx::shape::half_type, {1, 12, 256, 256}}; - - migraphx::program p1; - { - auto* mm = p1.get_main_module(); - auto a = mm->add_parameter("1", s1); - auto b = mm->add_parameter("2", s1); - auto b1 = mm->add_parameter("3", s1); - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), - b1); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); - rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), - rmax); - auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); - auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); - rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), - rsum); - auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); - auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); - mm->add_return({gemm2}); - } - run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 0}); - - // Expected result: only attention fusion, no flash decoding - migraphx::program p2; - { - auto* mm = p2.get_main_module(); - auto a = mm->add_parameter("1", s1); - auto b = mm->add_parameter("2", s1); - auto b1 = mm->add_parameter("3", s1); - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), - b1); - auto group = add_group( - p2, - "attn0", - "attention", - {a, b, b1}, - {"x0", "x1", "x2"}, - [=](auto* gm, const auto& inputs) { - auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); - auto rmax = - gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); - rmax = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rmax); - auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); - auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); - auto rsum = - gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); - rsum = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); - auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum); - auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); - return std::vector{gemm2}; - }); - mm->add_return({group}); - } - EXPECT(p1.sort() == p2.sort()); -} TEST_CASE(kv_cache_attention) { @@ -956,6 +888,399 @@ TEST_CASE(kv_cache_attention) EXPECT(p1.sort() == p2.sort()); } +// Test automatic splitting with num_splits = 0 (auto-calculate) +TEST_CASE(flash_decoding_3d_auto_split_large_sequence) +{ + // 3D Shape: [batch, sequence_length, head_dim] - Use larger sequence to trigger auto-splitting + migraphx::shape s_3d{migraphx::shape::half_type, {1, 512, 512}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Use auto-splitting: num_splits = 0, with sequence length 512 > threshold 32 + run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); + + // Expected program with automatic splitting (should calculate 16 splits for 512 sequence) + const std::size_t expected_splits = 16; // 512 = 2^9, split until chunk = 32, so 512/16 = 32 + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + size_t g_axis = 1; + + // New shapes for flash decoding with calculated splits + std::vector q_prime_shape = {1, expected_splits, 512, 512}; + std::vector k_prime_shape = {1, expected_splits, 512, 32}; // 512/16 = 32 + std::vector v_prime_shape = {1, expected_splits, 32, 512}; + + auto a_unsqueeze = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {g_axis}}}), a); + auto a_broadcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); + + auto b_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + // K: [1, 512, 512] -> [1, 512, 16, 32] -> [1, 16, 512, 32] + auto b_reshape_intermediate = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {1, 512, expected_splits, 32}}}), b_transpose); + auto b_reshape = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + b_reshape_intermediate); + + auto b1_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto b1_reshape = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1_transpose); + + auto group = add_group( + p2, + "attn0_flash_decoding", + "attention", + {a_broadcast, b_reshape, b1_reshape}, + {"x0", "x1", "x2"}, + [&](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 32}}}), + rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 32}}}), + rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + auto log = gm->add_instruction(migraphx::make_op("log"), rsum); + auto add = gm->add_instruction(migraphx::make_op("add"), rmax, log); + return std::vector{gemm2, add}; + }); + auto o_p = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), group); + auto lse = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), group); + + // Kernel 2 + auto k2_rmax = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {g_axis}}}), lse); + auto k2_broad1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), k2_rmax); + auto k2_sub = mm->add_instruction(migraphx::make_op("sub"), lse, k2_broad1); + auto k2_exp = mm->add_instruction(migraphx::make_op("exp"), k2_sub); + auto k2_rsum1 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_exp); + auto k2_broad2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), k2_rsum1); + auto k2_div = mm->add_instruction(migraphx::make_op("div"), k2_exp, k2_broad2); + auto k2_broad3 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), k2_div); + auto k2_convert = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), k2_broad3); + auto k2_mul = mm->add_instruction(migraphx::make_op("mul"), o_p, k2_convert); + auto k2_rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_mul); + auto k2_squeeze = + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {g_axis}}}), k2_rsum2); + mm->add_return({k2_squeeze}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(flash_decoding_3d_auto_split_small_sequence) +{ + // 3D Shape: [batch, sequence_length, head_dim] - Small sequence that should NOT trigger splitting + migraphx::shape s_3d{migraphx::shape::half_type, {1, 16, 16}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Use auto-splitting: num_splits = 0, with small sequence length 16 < threshold 32 + run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); + + // Expected program with regular attention (no flash decoding for small sequence) + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto group = add_group( + p2, + "attn0", + "attention", + {a, b, b1}, + {"x0", "x1", "x2"}, + [=](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + return std::vector{gemm2}; + }); + mm->add_return({group}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(flash_decoding_4d_auto_split_custom_params) +{ + // 4D Shape: [batch, heads, sequence_length, head_dim] - Test with custom parameters + migraphx::shape s1{migraphx::shape::half_type, {1, 12, 256, 256}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("1", s1); + auto b = mm->add_parameter("2", s1); + auto b1 = mm->add_parameter("3", s1); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); + rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Test with custom min_chunk_size and max_splits + run_pass(p1, {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, // Auto-calculate + .flash_decoding_threshold = 32, + .flash_decoding_max_splits = 4, // Smaller max splits + .flash_decoding_min_chunk_size = 64}); // Larger chunk size + + // Check for flash decoding + bool found_flash_decoding = false; + for(auto ins : *p1.get_main_module()) + { + if(ins.name().find("group") != std::string::npos) + { + found_flash_decoding = true; + break; + } + } + + EXPECT(found_flash_decoding); +} + +TEST_CASE(flash_decoding_auto_split_threshold_behavior) +{ + // Test threshold behavior - sequence right at the threshold boundary + migraphx::shape s_3d{migraphx::shape::half_type, {1, 127, 127}}; + + migraphx::program p1, p2; + + // Test 1: sequence length below threshold - should NOT split + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + run_pass(p1, {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, + .flash_decoding_threshold = 128, // Greater than sequence length (127) + .flash_decoding_max_splits = 8, + .flash_decoding_min_chunk_size = 32}); + + // Test 2: sequence length at threshold - should split + migraphx::shape s_3d_larger{migraphx::shape::half_type, {1, 128, 128}}; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d_larger); + auto b = mm->add_parameter("k", s_3d_larger); + auto b1 = mm->add_parameter("v", s_3d_larger); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + run_pass(p2, {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, + .flash_decoding_threshold = 128, // Equal to sequence length (128) + .flash_decoding_max_splits = 8, + .flash_decoding_min_chunk_size = 32}); + + // Check results - look for flash decoding by checking module names + bool found_flash_decoding_p1 = false, found_flash_decoding_p2 = false; + bool found_regular_attention_p1 = false; + + for(auto ins : *p1.get_main_module()) + { + if(ins.name().find("group") != std::string::npos) + { + // Check the module name to distinguish flash decoding from regular attention + auto module_inputs = ins.module_inputs(); + if(!module_inputs.empty()) + { + auto mod_name = module_inputs[0]->name(); + if(mod_name.find("flash_decoding") != std::string::npos) + { + found_flash_decoding_p1 = true; + } + else + { + found_regular_attention_p1 = true; + } + } + } + } + + for(auto ins : *p2.get_main_module()) + { + if(ins.name().find("group") != std::string::npos) + { + auto module_inputs = ins.module_inputs(); + if(!module_inputs.empty()) + { + auto mod_name = module_inputs[0]->name(); + if(mod_name.find("flash_decoding") != std::string::npos) + { + found_flash_decoding_p2 = true; + } + } + } + } + + // Below threshold: should have regular attention, not flash decoding + EXPECT(not found_flash_decoding_p1); + EXPECT(found_regular_attention_p1); // Should have regular attention instead + // At threshold: should have flash decoding + EXPECT(found_flash_decoding_p2); +} + +TEST_CASE(flash_decoding_auto_split_max_splits_constraint) +{ + // Test that max_splits constraint is respected + migraphx::shape s_3d{migraphx::shape::half_type, {1, 2048, 2048}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Use small max_splits to test constraint + run_pass(p1, {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, // Auto-calculate + .flash_decoding_threshold = 32, + .flash_decoding_max_splits = 4, // Small max_splits + .flash_decoding_min_chunk_size = 64}); + + // Check that flash decoding was applied + bool found_flash_decoding = false; + for(auto ins : *p1.get_main_module()) + { + if(ins.name().find("group") != std::string::npos) + { + found_flash_decoding = true; + break; + } + } + + EXPECT(found_flash_decoding); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); From 26324c8ed6685d2b8651f28ba6c687aeba654c97 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Mon, 29 Dec 2025 15:35:56 -0600 Subject: [PATCH 25/30] format --- src/fuse_attention.cpp | 35 +++++---- src/include/migraphx/fuse_attention.hpp | 8 +- test/fuse_attention.cpp | 97 ++++++++++++++----------- 3 files changed, 81 insertions(+), 59 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 8e850079b82..7b23c78ffd7 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -45,8 +45,9 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); -// Helper function to get config value with priority: struct member (if not default) > env var > default -template +// Helper function to get config value with priority: struct member (if not default) > env var > +// default +template std::size_t get_config_value(std::size_t struct_value, std::size_t default_value, EnvVar env_var) { // if struct member is not the default value, use it @@ -83,8 +84,11 @@ inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, // calculate the actual number of groups for flash decoding // returns 0 if no splitting should be performed -inline std::size_t -calculate_groups(std::size_t groups, std::size_t sequence_length, std::size_t threshold, std::size_t min_chunk_size, std::size_t max_splits) +inline std::size_t calculate_groups(std::size_t groups, + std::size_t sequence_length, + std::size_t threshold, + std::size_t min_chunk_size, + std::size_t max_splits) { // if groups is explicitly set and valid, use it if(groups > 1) @@ -529,11 +533,15 @@ struct find_flash_decoding // read configuration with priority: struct member (if not default) > env var > default std::size_t groups = get_num_splits(configured_splits); - std::size_t threshold = get_config_value(configured_threshold, 32, MIGRAPHX_FLASH_DECODING_THRESHOLD{}); - std::size_t min_chunk_size = get_config_value(configured_min_chunk_size, 32, MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}); - std::size_t max_splits = get_config_value(configured_max_splits, 16, MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}); + std::size_t threshold = + get_config_value(configured_threshold, 32, MIGRAPHX_FLASH_DECODING_THRESHOLD{}); + std::size_t min_chunk_size = get_config_value( + configured_min_chunk_size, 32, MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}); + std::size_t max_splits = + get_config_value(configured_max_splits, 16, MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}); - std::size_t actual_groups = calculate_groups(groups, sequence_length, threshold, min_chunk_size, max_splits); + std::size_t actual_groups = + calculate_groups(groups, sequence_length, threshold, min_chunk_size, max_splits); if(actual_groups == 0) return; @@ -909,11 +917,12 @@ void fuse_attention::apply(module_pass_manager& mpm) const std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); if(configured_splits > 0 or flash_decoding_enabled) { - match::find_matches(mpm, - find_flash_decoding{.configured_splits = flash_decoding_num_splits, - .configured_threshold = flash_decoding_threshold, - .configured_max_splits = flash_decoding_max_splits, - .configured_min_chunk_size = flash_decoding_min_chunk_size}); + match::find_matches( + mpm, + find_flash_decoding{.configured_splits = flash_decoding_num_splits, + .configured_threshold = flash_decoding_threshold, + .configured_max_splits = flash_decoding_max_splits, + .configured_min_chunk_size = flash_decoding_min_chunk_size}); mpm.run_pass(dead_code_elimination{}); } } diff --git a/src/include/migraphx/fuse_attention.hpp b/src/include/migraphx/fuse_attention.hpp index 923fec72bc4..4ad67d644f4 100644 --- a/src/include/migraphx/fuse_attention.hpp +++ b/src/include/migraphx/fuse_attention.hpp @@ -38,10 +38,10 @@ struct module_pass_manager; struct MIGRAPHX_EXPORT fuse_attention { bool attn_enabled = false; - bool flash_decoding_enabled = false; - std::size_t flash_decoding_num_splits = 0; - std::size_t flash_decoding_threshold = 32; - std::size_t flash_decoding_max_splits = 16; + bool flash_decoding_enabled = false; + std::size_t flash_decoding_num_splits = 0; + std::size_t flash_decoding_threshold = 32; + std::size_t flash_decoding_max_splits = 16; std::size_t flash_decoding_min_chunk_size = 32; std::string name() const { return "fuse_attention"; } diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 134f973afda..4f5b85389ab 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -42,7 +42,6 @@ #include #include - static void run_pass(migraphx::program& p, migraphx::fuse_attention fa = {}) { migraphx::run_passes(p, {fa, migraphx::dead_code_elimination{}}); @@ -57,8 +56,8 @@ TEST_CASE(get_num_splits_from_member) // Test that struct members are set correctly EXPECT(fa.flash_decoding_num_splits == 8); - EXPECT(fa.flash_decoding_threshold == 32); // default value - EXPECT(fa.flash_decoding_max_splits == 16); // default value + EXPECT(fa.flash_decoding_threshold == 32); // default value + EXPECT(fa.flash_decoding_max_splits == 16); // default value EXPECT(fa.flash_decoding_min_chunk_size == 32); // default value } @@ -519,7 +518,8 @@ TEST_CASE(gemm_softmax_gemm_flash_decoding) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -620,7 +620,8 @@ TEST_CASE(flash_decoding_3d) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { @@ -709,7 +710,6 @@ TEST_CASE(flash_decoding_3d) EXPECT(p1.sort() == p2.sort()); } - TEST_CASE(kv_cache_attention) { migraphx::shape s1{migraphx::shape::half_type, {1}}; @@ -916,7 +916,8 @@ TEST_CASE(flash_decoding_3d_auto_split_large_sequence) mm->add_return({gemm2}); } // Use auto-splitting: num_splits = 0, with sequence length 512 > threshold 32 - run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); // Expected program with automatic splitting (should calculate 16 splits for 512 sequence) const std::size_t expected_splits = 16; // 512 = 2^9, split until chunk = 32, so 512/16 = 32 @@ -930,7 +931,7 @@ TEST_CASE(flash_decoding_3d_auto_split_large_sequence) // New shapes for flash decoding with calculated splits std::vector q_prime_shape = {1, expected_splits, 512, 512}; - std::vector k_prime_shape = {1, expected_splits, 512, 32}; // 512/16 = 32 + std::vector k_prime_shape = {1, expected_splits, 512, 32}; // 512/16 = 32 std::vector v_prime_shape = {1, expected_splits, 32, 512}; auto a_unsqueeze = @@ -963,14 +964,16 @@ TEST_CASE(flash_decoding_3d_auto_split_large_sequence) auto rmax = gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); auto rmax_broad = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 32}}}), + migraphx::make_op("multibroadcast", + {{"out_lens", {1, expected_splits, 512, 32}}}), rmax); auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); auto rsum = gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); auto rsum_broad = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 32}}}), + migraphx::make_op("multibroadcast", + {{"out_lens", {1, expected_splits, 512, 32}}}), rsum); auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); @@ -985,13 +988,15 @@ TEST_CASE(flash_decoding_3d_auto_split_large_sequence) auto k2_rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {g_axis}}}), lse); auto k2_broad1 = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), k2_rmax); + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), + k2_rmax); auto k2_sub = mm->add_instruction(migraphx::make_op("sub"), lse, k2_broad1); auto k2_exp = mm->add_instruction(migraphx::make_op("exp"), k2_sub); auto k2_rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_exp); auto k2_broad2 = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), k2_rsum1); + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), + k2_rsum1); auto k2_div = mm->add_instruction(migraphx::make_op("div"), k2_exp, k2_broad2); auto k2_broad3 = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), k2_div); @@ -1009,7 +1014,8 @@ TEST_CASE(flash_decoding_3d_auto_split_large_sequence) TEST_CASE(flash_decoding_3d_auto_split_small_sequence) { - // 3D Shape: [batch, sequence_length, head_dim] - Small sequence that should NOT trigger splitting + // 3D Shape: [batch, sequence_length, head_dim] - Small sequence that should NOT trigger + // splitting migraphx::shape s_3d{migraphx::shape::half_type, {1, 16, 16}}; migraphx::program p1; @@ -1034,7 +1040,8 @@ TEST_CASE(flash_decoding_3d_auto_split_small_sequence) mm->add_return({gemm2}); } // Use auto-splitting: num_splits = 0, with small sequence length 16 < threshold 32 - run_pass(p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); // Expected program with regular attention (no flash decoding for small sequence) migraphx::program p2; @@ -1043,7 +1050,7 @@ TEST_CASE(flash_decoding_3d_auto_split_small_sequence) auto a = mm->add_parameter("q", s_3d); auto b = mm->add_parameter("k", s_3d); auto b1 = mm->add_parameter("v", s_3d); - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); auto group = add_group( p2, @@ -1056,13 +1063,15 @@ TEST_CASE(flash_decoding_3d_auto_split_small_sequence) auto rmax = gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); auto rmax_broad = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rmax); auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); auto rsum = gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); auto rsum_broad = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rsum); auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); return std::vector{gemm2}; @@ -1100,12 +1109,13 @@ TEST_CASE(flash_decoding_4d_auto_split_custom_params) mm->add_return({gemm2}); } // Test with custom min_chunk_size and max_splits - run_pass(p1, {.attn_enabled = true, - .flash_decoding_enabled = true, - .flash_decoding_num_splits = 0, // Auto-calculate - .flash_decoding_threshold = 32, - .flash_decoding_max_splits = 4, // Smaller max splits - .flash_decoding_min_chunk_size = 64}); // Larger chunk size + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, // Auto-calculate + .flash_decoding_threshold = 32, + .flash_decoding_max_splits = 4, // Smaller max splits + .flash_decoding_min_chunk_size = 64}); // Larger chunk size // Check for flash decoding bool found_flash_decoding = false; @@ -1149,12 +1159,13 @@ TEST_CASE(flash_decoding_auto_split_threshold_behavior) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, - .flash_decoding_enabled = true, - .flash_decoding_num_splits = 0, - .flash_decoding_threshold = 128, // Greater than sequence length (127) - .flash_decoding_max_splits = 8, - .flash_decoding_min_chunk_size = 32}); + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, + .flash_decoding_threshold = 128, // Greater than sequence length (127) + .flash_decoding_max_splits = 8, + .flash_decoding_min_chunk_size = 32}); // Test 2: sequence length at threshold - should split migraphx::shape s_3d_larger{migraphx::shape::half_type, {1, 128, 128}}; @@ -1178,12 +1189,13 @@ TEST_CASE(flash_decoding_auto_split_threshold_behavior) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p2, {.attn_enabled = true, - .flash_decoding_enabled = true, - .flash_decoding_num_splits = 0, - .flash_decoding_threshold = 128, // Equal to sequence length (128) - .flash_decoding_max_splits = 8, - .flash_decoding_min_chunk_size = 32}); + run_pass(p2, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, + .flash_decoding_threshold = 128, // Equal to sequence length (128) + .flash_decoding_max_splits = 8, + .flash_decoding_min_chunk_size = 32}); // Check results - look for flash decoding by checking module names bool found_flash_decoding_p1 = false, found_flash_decoding_p2 = false; @@ -1228,7 +1240,7 @@ TEST_CASE(flash_decoding_auto_split_threshold_behavior) // Below threshold: should have regular attention, not flash decoding EXPECT(not found_flash_decoding_p1); - EXPECT(found_regular_attention_p1); // Should have regular attention instead + EXPECT(found_regular_attention_p1); // Should have regular attention instead // At threshold: should have flash decoding EXPECT(found_flash_decoding_p2); } @@ -1260,12 +1272,13 @@ TEST_CASE(flash_decoding_auto_split_max_splits_constraint) mm->add_return({gemm2}); } // Use small max_splits to test constraint - run_pass(p1, {.attn_enabled = true, - .flash_decoding_enabled = true, - .flash_decoding_num_splits = 0, // Auto-calculate - .flash_decoding_threshold = 32, - .flash_decoding_max_splits = 4, // Small max_splits - .flash_decoding_min_chunk_size = 64}); + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, // Auto-calculate + .flash_decoding_threshold = 32, + .flash_decoding_max_splits = 4, // Small max_splits + .flash_decoding_min_chunk_size = 64}); // Check that flash decoding was applied bool found_flash_decoding = false; From de3eb6b7c4aac89869d99b681046f85adc342d7d Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Tue, 30 Dec 2025 11:26:06 -0600 Subject: [PATCH 26/30] tidy --- test/fuse_attention.cpp | 50 +++++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 12 deletions(-) diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 4f5b85389ab..ba1290f4402 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -1119,7 +1119,7 @@ TEST_CASE(flash_decoding_4d_auto_split_custom_params) // Check for flash decoding bool found_flash_decoding = false; - for(auto ins : *p1.get_main_module()) + for(const auto& ins : *p1.get_main_module()) { if(ins.name().find("group") != std::string::npos) { @@ -1136,7 +1136,8 @@ TEST_CASE(flash_decoding_auto_split_threshold_behavior) // Test threshold behavior - sequence right at the threshold boundary migraphx::shape s_3d{migraphx::shape::half_type, {1, 127, 127}}; - migraphx::program p1, p2; + migraphx::program p1; + migraphx::program p2; // Test 1: sequence length below threshold - should NOT split { @@ -1198,18 +1199,19 @@ TEST_CASE(flash_decoding_auto_split_threshold_behavior) .flash_decoding_min_chunk_size = 32}); // Check results - look for flash decoding by checking module names - bool found_flash_decoding_p1 = false, found_flash_decoding_p2 = false; + bool found_flash_decoding_p1 = false; + bool found_flash_decoding_p2 = false; bool found_regular_attention_p1 = false; - for(auto ins : *p1.get_main_module()) + for(const auto& ins : *p1.get_main_module()) { if(ins.name().find("group") != std::string::npos) { // Check the module name to distinguish flash decoding from regular attention - auto module_inputs = ins.module_inputs(); - if(!module_inputs.empty()) + const auto& module_inputs = ins.module_inputs(); + if(not module_inputs.empty()) { - auto mod_name = module_inputs[0]->name(); + const auto& mod_name = module_inputs[0]->name(); if(mod_name.find("flash_decoding") != std::string::npos) { found_flash_decoding_p1 = true; @@ -1222,14 +1224,14 @@ TEST_CASE(flash_decoding_auto_split_threshold_behavior) } } - for(auto ins : *p2.get_main_module()) + for(const auto& ins : *p2.get_main_module()) { if(ins.name().find("group") != std::string::npos) { - auto module_inputs = ins.module_inputs(); - if(!module_inputs.empty()) + const auto& module_inputs = ins.module_inputs(); + if(not module_inputs.empty()) { - auto mod_name = module_inputs[0]->name(); + const auto& mod_name = module_inputs[0]->name(); if(mod_name.find("flash_decoding") != std::string::npos) { found_flash_decoding_p2 = true; @@ -1282,7 +1284,7 @@ TEST_CASE(flash_decoding_auto_split_max_splits_constraint) // Check that flash decoding was applied bool found_flash_decoding = false; - for(auto ins : *p1.get_main_module()) + for(const auto& ins : *p1.get_main_module()) { if(ins.name().find("group") != std::string::npos) { @@ -1294,6 +1296,30 @@ TEST_CASE(flash_decoding_auto_split_max_splits_constraint) EXPECT(found_flash_decoding); } +TEST_CASE(ceil_mul_of_function) +{ + // Test the ceil_mul_of function used for padding calculations + + // Test exact multiples - no padding needed + EXPECT(migraphx::ceil_mul_of(16, 4) == 16); // 16 is multiple of 4 + EXPECT(migraphx::ceil_mul_of(32, 8) == 32); // 32 is multiple of 8 + EXPECT(migraphx::ceil_mul_of(100, 10) == 100); // 100 is multiple of 10 + + // Test non-multiples - padding needed + EXPECT(migraphx::ceil_mul_of(17, 4) == 20); // 17 -> 20 (next multiple of 4) + EXPECT(migraphx::ceil_mul_of(33, 8) == 40); // 33 -> 40 (next multiple of 8) + EXPECT(migraphx::ceil_mul_of(101, 10) == 110); // 101 -> 110 (next multiple of 10) + + // Test edge cases + EXPECT(migraphx::ceil_mul_of(1, 4) == 4); // 1 -> 4 + EXPECT(migraphx::ceil_mul_of(0, 4) == 0); // 0 -> 0 + + // Test specific cases from attention fusion + EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 100 -> 104 (padding = 4) + EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 127 -> 128 (padding = 1) + EXPECT(migraphx::ceil_mul_of(2049, 32) == 2080); // 2049 -> 2080 (padding = 31) +} + int main(int argc, const char* argv[]) { test::run(argc, argv); From 9cfd66409d5b3f30afdeb3aeb3f4acec4b41fbc0 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 31 Dec 2025 14:34:15 -0600 Subject: [PATCH 27/30] add more tests for codecov and fix bug in padding --- src/fuse_attention.cpp | 41 ++++++- test/fuse_attention.cpp | 264 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 290 insertions(+), 15 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 7b23c78ffd7..ae3db67a958 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -562,9 +562,20 @@ struct find_flash_decoding auto k = map_param_to_main.at(k_param); auto v = map_param_to_main.at(v_param); - // pad K and V if necessary + // save original references before padding (needed for group_inputs replacement later) + auto q_orig = q; + auto k_orig = k; + auto v_orig = v; + + // pad Q, K and V if necessary if(padding_needed > 0) { + // Q shape: [B, M, k] or [B, H, M, k] for 4D. Padding on M (sequence length dim) + auto q_ndim = q->get_shape().ndim(); + std::vector q_pads(2 * q_ndim, 0); + q_pads[q_ndim + q_ndim - 2] = padding_needed; // pad right on M dim (second to last) + q = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", q_pads}}), q); + // K shape: [B, k, N] or [B, H, k, N] for 4D. Padding on N auto k_ndim = k->get_shape().ndim(); std::vector k_pads(2 * k_ndim, 0); @@ -609,18 +620,19 @@ struct find_flash_decoding attn_group_ins, make_op("reshape", {{"dims", transform_info.v_shape}}), v); // create new input list by replacing Q, K, V with reshaped versions + // use original references (before padding) for comparison std::vector new_group_inputs = group_inputs; for(size_t i = 0; i < group_inputs.size(); ++i) { - if(group_inputs[i] == q) + if(group_inputs[i] == q_orig) { new_group_inputs[i] = q_reshaped; } - else if(group_inputs[i] == k) + else if(group_inputs[i] == k_orig) { new_group_inputs[i] = k_reshaped; } - else if(group_inputs[i] == v) + else if(group_inputs[i] == v_orig) { new_group_inputs[i] = v_reshaped; } @@ -720,8 +732,27 @@ struct find_flash_decoding auto final_squeezed_o = mm.insert_instruction( attn_group_ins, make_op("squeeze", {{"axes", {g_axis}}}), final_output_o); + // if padding was applied, slice to remove it + instruction_ref final_result = final_squeezed_o; + if(padding_needed > 0) + { + // need to slice the sequence dimension to remove padding + // final_squeezed_o has shape like [B, M_padded, D], need to slice M back to original + auto output_shape = final_squeezed_o->get_shape(); + auto output_lens = output_shape.lens(); + std::size_t seq_dim_idx = output_lens.size() - 2; // sequence dim is second to last + std::size_t original_seq_len = output_lens[seq_dim_idx] - padding_needed; + + final_result = mm.insert_instruction( + attn_group_ins, + make_op("slice", {{"axes", {seq_dim_idx}}, + {"starts", {0}}, + {"ends", {original_seq_len}}}), + final_squeezed_o); + } + // replace the original group instruction with the final result - mm.replace_instruction(attn_group_ins, final_squeezed_o); + mm.replace_instruction(attn_group_ins, final_result); } }; diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index ba1290f4402..0f29a2373a3 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -710,6 +710,250 @@ TEST_CASE(flash_decoding_3d) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(flash_decoding_3d_rectangular) +{ + // 3D Shape: [batch, head_dim, sequence_length] + migraphx::shape s_3d{migraphx::shape::half_type, {1, 256, 240}}; + migraphx::shape st_3d{migraphx::shape::half_type, {1, 240, 256}}; + const std::size_t num_splits = 2; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); // [1, 256, 240] + auto b = mm->add_parameter("k", s_3d); // [1, 256, 240] + auto b1 = mm->add_parameter("v", st_3d); // [1, 240, 256] + a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); // [1, 240, 256] + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // [1, 240, 256] x [1, 256, 240] = [1, 240, 240] + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); // [1, 240, 1] + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); // [1, 240, 240] + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // [1, 240, 240] + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 240, 240] + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); // [1, 240, 1] + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); // [1, 240, 240] + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // [1, 240, 240] + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // [1, 240, 240] x [1, 240, 256] = [1, 240, 256] + mm->add_return({gemm2}); + } + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = num_splits}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", st_3d); + size_t g_axis = 1; + + // New shapes for flash decoding - 240 split into 2x120 + std::vector q_prime_shape = {1, num_splits, 240, 256}; + std::vector k_prime_shape = {1, num_splits, 256, 120}; + std::vector v_prime_shape = {1, num_splits, 120, 256}; + + auto a_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); + auto a_unsqueeze = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {g_axis}}}), a_transpose); + auto a_broadcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); + + // K: [1, 256, 240] -> [1, 256, 2, 120] -> [1, 2, 256, 120] + auto b_reshape_intermediate = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {1, 256, 2, 120}}}), b); + auto b_reshape = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + b_reshape_intermediate); + + // V: [1, 240, 256] -> [1, 2, 120, 256] -> [1, 2, 120, 256] + auto b1_reshape = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1); + + auto group = add_group( + p2, + "attn0_flash_decoding", + "attention", + {a_broadcast, b_reshape, b1_reshape}, + {"x0", "x1", "x2"}, + [&](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 120}}}), + rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 120}}}), + rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + auto log = gm->add_instruction(migraphx::make_op("log"), rsum); + auto add = gm->add_instruction(migraphx::make_op("add"), rmax, log); + return std::vector{gemm2, add}; + }); + auto o_p = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), group); + auto lse = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), group); + + // Kernel 2 + auto k2_rmax = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {g_axis}}}), lse); + auto k2_broad1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 1}}}), k2_rmax); + auto k2_sub = mm->add_instruction(migraphx::make_op("sub"), lse, k2_broad1); + auto k2_exp = mm->add_instruction(migraphx::make_op("exp"), k2_sub); + auto k2_rsum1 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_exp); + auto k2_broad2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 1}}}), k2_rsum1); + auto k2_div = mm->add_instruction(migraphx::make_op("div"), k2_exp, k2_broad2); + auto k2_broad3 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), k2_div); + auto k2_convert = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), k2_broad3); + auto k2_mul = mm->add_instruction(migraphx::make_op("mul"), o_p, k2_convert); + auto k2_rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_mul); + auto k2_squeeze = + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {g_axis}}}), k2_rsum2); + mm->add_return({k2_squeeze}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(flash_decoding_3d_padding) +{ + // 3D Shape: [batch, head_dim, sequence_length] + migraphx::shape s_3d{migraphx::shape::half_type, {1, 256, 241}}; + migraphx::shape st_3d{migraphx::shape::half_type, {1, 241, 256}}; + const std::size_t num_splits = 2; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); // [1, 256, 241] + auto b = mm->add_parameter("k", s_3d); // [1, 256, 241] + auto b1 = mm->add_parameter("v", st_3d); // [1, 241, 256] + a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); // [1, 241, 256] + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // [1, 241, 256] x [1, 256, 241] = [1, 241, 241] + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); // [1, 241, 1] + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); // [1, 241, 241] + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // [1, 241, 241] + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 241, 241] + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); // [1, 241, 1] + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); // [1, 241, 241] + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // [1, 241, 241] + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // [1, 241, 241] x [1, 241, 256] = [1, 241, 256] + mm->add_return({gemm2}); + } + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = num_splits}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", st_3d); + size_t g_axis = 1; + + // Padding 241 -> 242, split into 2x121 + std::vector q_prime_shape = {1, num_splits, 242, 256}; + std::vector k_prime_shape = {1, num_splits, 256, 121}; + std::vector v_prime_shape = {1, num_splits, 121, 256}; + + // Q: [1, 256, 241] -> transpose -> [1, 241, 256] -> pad -> [1, 242, 256] + auto a_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); + auto a_padded = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 1, 0}}}), a_transpose); + auto a_unsqueeze = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {g_axis}}}), a_padded); + auto a_broadcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); + + // K: [1, 256, 241] -> pad -> [1, 256, 242] -> reshape -> [1, 256, 2, 121] -> transpose -> [1, 2, 256, 121] + auto b_padded = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 1}}}), b); + auto b_reshape_intermediate = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {1, 256, 2, 121}}}), b_padded); + auto b_reshape = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + b_reshape_intermediate); + + // V: [1, 241, 256] -> pad -> [1, 242, 256] -> reshape -> [1, 2, 121, 256] + auto b1_padded = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 1, 0}}}), b1); + auto b1_reshape = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1_padded); + + auto group = add_group( + p2, + "attn0_flash_decoding", + "attention", + {a_broadcast, b_reshape, b1_reshape}, + {"x0", "x1", "x2"}, + [&](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 121}}}), + rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 121}}}), + rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + auto log = gm->add_instruction(migraphx::make_op("log"), rsum); + auto add = gm->add_instruction(migraphx::make_op("add"), rmax, log); + return std::vector{gemm2, add}; + }); + auto o_p = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), group); + auto lse = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), group); + + // Kernel 2 + auto k2_rmax = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {g_axis}}}), lse); + auto k2_broad1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 1}}}), k2_rmax); + auto k2_sub = mm->add_instruction(migraphx::make_op("sub"), lse, k2_broad1); + auto k2_exp = mm->add_instruction(migraphx::make_op("exp"), k2_sub); + auto k2_rsum1 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_exp); + auto k2_broad2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 1}}}), k2_rsum1); + auto k2_div = mm->add_instruction(migraphx::make_op("div"), k2_exp, k2_broad2); + auto k2_broad3 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), k2_div); + auto k2_convert = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), k2_broad3); + auto k2_mul = mm->add_instruction(migraphx::make_op("mul"), o_p, k2_convert); + auto k2_rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_mul); + auto k2_squeeze = + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {g_axis}}}), k2_rsum2); + + // Slice to remove padding: [1, 242, 256] -> [1, 241, 256] + auto sliced = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {241}}}), k2_squeeze); + + mm->add_return({sliced}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(kv_cache_attention) { migraphx::shape s1{migraphx::shape::half_type, {1}}; @@ -1199,8 +1443,8 @@ TEST_CASE(flash_decoding_auto_split_threshold_behavior) .flash_decoding_min_chunk_size = 32}); // Check results - look for flash decoding by checking module names - bool found_flash_decoding_p1 = false; - bool found_flash_decoding_p2 = false; + bool found_flash_decoding_p1 = false; + bool found_flash_decoding_p2 = false; bool found_regular_attention_p1 = false; for(const auto& ins : *p1.get_main_module()) @@ -1301,22 +1545,22 @@ TEST_CASE(ceil_mul_of_function) // Test the ceil_mul_of function used for padding calculations // Test exact multiples - no padding needed - EXPECT(migraphx::ceil_mul_of(16, 4) == 16); // 16 is multiple of 4 - EXPECT(migraphx::ceil_mul_of(32, 8) == 32); // 32 is multiple of 8 + EXPECT(migraphx::ceil_mul_of(16, 4) == 16); // 16 is multiple of 4 + EXPECT(migraphx::ceil_mul_of(32, 8) == 32); // 32 is multiple of 8 EXPECT(migraphx::ceil_mul_of(100, 10) == 100); // 100 is multiple of 10 // Test non-multiples - padding needed - EXPECT(migraphx::ceil_mul_of(17, 4) == 20); // 17 -> 20 (next multiple of 4) - EXPECT(migraphx::ceil_mul_of(33, 8) == 40); // 33 -> 40 (next multiple of 8) + EXPECT(migraphx::ceil_mul_of(17, 4) == 20); // 17 -> 20 (next multiple of 4) + EXPECT(migraphx::ceil_mul_of(33, 8) == 40); // 33 -> 40 (next multiple of 8) EXPECT(migraphx::ceil_mul_of(101, 10) == 110); // 101 -> 110 (next multiple of 10) // Test edge cases - EXPECT(migraphx::ceil_mul_of(1, 4) == 4); // 1 -> 4 - EXPECT(migraphx::ceil_mul_of(0, 4) == 0); // 0 -> 0 + EXPECT(migraphx::ceil_mul_of(1, 4) == 4); // 1 -> 4 + EXPECT(migraphx::ceil_mul_of(0, 4) == 0); // 0 -> 0 // Test specific cases from attention fusion - EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 100 -> 104 (padding = 4) - EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 127 -> 128 (padding = 1) + EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 100 -> 104 (padding = 4) + EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 127 -> 128 (padding = 1) EXPECT(migraphx::ceil_mul_of(2049, 32) == 2080); // 2049 -> 2080 (padding = 31) } From 7173c41fcecb4542319961e7bf463b97cbb24029 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 31 Dec 2025 16:04:08 -0600 Subject: [PATCH 28/30] format --- src/fuse_attention.cpp | 11 ++-- test/fuse_attention.cpp | 109 +++++++++++++++++++++++----------------- 2 files changed, 67 insertions(+), 53 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index ae3db67a958..1455c8e58a5 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -738,16 +738,15 @@ struct find_flash_decoding { // need to slice the sequence dimension to remove padding // final_squeezed_o has shape like [B, M_padded, D], need to slice M back to original - auto output_shape = final_squeezed_o->get_shape(); - auto output_lens = output_shape.lens(); - std::size_t seq_dim_idx = output_lens.size() - 2; // sequence dim is second to last + auto output_shape = final_squeezed_o->get_shape(); + const auto& output_lens = output_shape.lens(); + std::size_t seq_dim_idx = output_lens.size() - 2; // sequence dim is second to last std::size_t original_seq_len = output_lens[seq_dim_idx] - padding_needed; final_result = mm.insert_instruction( attn_group_ins, - make_op("slice", {{"axes", {seq_dim_idx}}, - {"starts", {0}}, - {"ends", {original_seq_len}}}), + make_op("slice", + {{"axes", {seq_dim_idx}}, {"starts", {0}}, {"ends", {original_seq_len}}}), final_squeezed_o); } diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 0f29a2373a3..562d63e7842 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -720,25 +720,34 @@ TEST_CASE(flash_decoding_3d_rectangular) migraphx::program p1; { auto* mm = p1.get_main_module(); - auto a = mm->add_parameter("q", s_3d); // [1, 256, 240] - auto b = mm->add_parameter("k", s_3d); // [1, 256, 240] + auto a = mm->add_parameter("q", s_3d); // [1, 256, 240] + auto b = mm->add_parameter("k", s_3d); // [1, 256, 240] auto b1 = mm->add_parameter("v", st_3d); // [1, 240, 256] - a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); // [1, 240, 256] - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // [1, 240, 256] x [1, 256, 240] = [1, 240, 240] - auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); // [1, 240, 1] - rmax = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); // [1, 240, 240] + a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), + a); // [1, 240, 256] + auto gemm1 = mm->add_instruction( + migraphx::make_op("dot"), a, b); // [1, 240, 256] x [1, 256, 240] = [1, 240, 240] + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), + gemm1); // [1, 240, 1] + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rmax); // [1, 240, 240] auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // [1, 240, 240] - auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 240, 240] - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); // [1, 240, 1] + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 240, 240] + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), + exp); // [1, 240, 1] rsum = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); // [1, 240, 240] + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rsum); // [1, 240, 240] auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // [1, 240, 240] - auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // [1, 240, 240] x [1, 240, 256] = [1, 240, 256] + auto gemm2 = mm->add_instruction( + migraphx::make_op("dot"), div, b1); // [1, 240, 240] x [1, 240, 256] = [1, 240, 256] mm->add_return({gemm2}); } - run_pass( - p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = num_splits}); + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = num_splits}); migraphx::program p2; { @@ -761,15 +770,15 @@ TEST_CASE(flash_decoding_3d_rectangular) migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); // K: [1, 256, 240] -> [1, 256, 2, 120] -> [1, 2, 256, 120] - auto b_reshape_intermediate = mm->add_instruction( - migraphx::make_op("reshape", {{"dims", {1, 256, 2, 120}}}), b); + auto b_reshape_intermediate = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 256, 2, 120}}}), b); auto b_reshape = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), b_reshape_intermediate); // V: [1, 240, 256] -> [1, 2, 120, 256] -> [1, 2, 120, 256] - auto b1_reshape = mm->add_instruction( - migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1); + auto b1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1); auto group = add_group( p2, @@ -836,25 +845,34 @@ TEST_CASE(flash_decoding_3d_padding) migraphx::program p1; { auto* mm = p1.get_main_module(); - auto a = mm->add_parameter("q", s_3d); // [1, 256, 241] - auto b = mm->add_parameter("k", s_3d); // [1, 256, 241] + auto a = mm->add_parameter("q", s_3d); // [1, 256, 241] + auto b = mm->add_parameter("k", s_3d); // [1, 256, 241] auto b1 = mm->add_parameter("v", st_3d); // [1, 241, 256] - a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); // [1, 241, 256] - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // [1, 241, 256] x [1, 256, 241] = [1, 241, 241] - auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); // [1, 241, 1] - rmax = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); // [1, 241, 241] + a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), + a); // [1, 241, 256] + auto gemm1 = mm->add_instruction( + migraphx::make_op("dot"), a, b); // [1, 241, 256] x [1, 256, 241] = [1, 241, 241] + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), + gemm1); // [1, 241, 1] + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rmax); // [1, 241, 241] auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // [1, 241, 241] - auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 241, 241] - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); // [1, 241, 1] + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 241, 241] + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), + exp); // [1, 241, 1] rsum = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); // [1, 241, 241] + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rsum); // [1, 241, 241] auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // [1, 241, 241] - auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // [1, 241, 241] x [1, 241, 256] = [1, 241, 256] + auto gemm2 = mm->add_instruction( + migraphx::make_op("dot"), div, b1); // [1, 241, 241] x [1, 241, 256] = [1, 241, 256] mm->add_return({gemm2}); } - run_pass( - p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = num_splits}); + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = num_splits}); migraphx::program p2; { @@ -879,9 +897,10 @@ TEST_CASE(flash_decoding_3d_padding) auto a_broadcast = mm->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); - // K: [1, 256, 241] -> pad -> [1, 256, 242] -> reshape -> [1, 256, 2, 121] -> transpose -> [1, 2, 256, 121] - auto b_padded = mm->add_instruction( - migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 1}}}), b); + // K: [1, 256, 241] -> pad -> [1, 256, 242] -> reshape -> [1, 256, 2, 121] -> transpose -> + // [1, 2, 256, 121] + auto b_padded = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 1}}}), b); auto b_reshape_intermediate = mm->add_instruction( migraphx::make_op("reshape", {{"dims", {1, 256, 2, 121}}}), b_padded); auto b_reshape = @@ -889,10 +908,10 @@ TEST_CASE(flash_decoding_3d_padding) b_reshape_intermediate); // V: [1, 241, 256] -> pad -> [1, 242, 256] -> reshape -> [1, 2, 121, 256] - auto b1_padded = mm->add_instruction( - migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 1, 0}}}), b1); - auto b1_reshape = mm->add_instruction( - migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1_padded); + auto b1_padded = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 1, 0}}}), b1); + auto b1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1_padded); auto group = add_group( p2, @@ -947,7 +966,8 @@ TEST_CASE(flash_decoding_3d_padding) // Slice to remove padding: [1, 242, 256] -> [1, 241, 256] auto sliced = mm->add_instruction( - migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {241}}}), k2_squeeze); + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {241}}}), + k2_squeeze); mm->add_return({sliced}); } @@ -1527,15 +1547,10 @@ TEST_CASE(flash_decoding_auto_split_max_splits_constraint) .flash_decoding_min_chunk_size = 64}); // Check that flash decoding was applied - bool found_flash_decoding = false; - for(const auto& ins : *p1.get_main_module()) - { - if(ins.name().find("group") != std::string::npos) - { - found_flash_decoding = true; - break; - } - } + auto* mod = p1.get_main_module(); + bool found_flash_decoding = std::any_of(mod->begin(), mod->end(), [](const auto& ins) { + return ins.name().find("group") != std::string::npos; + }); EXPECT(found_flash_decoding); } From b2cd8ad8932761cc3a79132114245e1e54a1c936 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 31 Dec 2025 19:04:49 -0600 Subject: [PATCH 29/30] change a loop, disable flash decoding on mi300 --- src/targets/gpu/fuse_mlir.cpp | 6 +----- src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp | 2 +- src/targets/gpu/target.cpp | 2 +- test/fuse_attention.cpp | 13 ++++--------- 4 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 4cbb8766744..937a5eddc8c 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -144,7 +144,7 @@ bool mlir_attention_enabled(context* ctx) #endif } -bool mlir_flash_decoding_enabled(context* ctx) +bool mlir_flash_decoding_enabled() { #ifdef MIGRAPHX_MLIR if(not mlir_enabled()) @@ -154,10 +154,6 @@ bool mlir_flash_decoding_enabled(context* ctx) if(enabled(MIGRAPHX_FLASH_DECODING_ENABLED{})) return true; - // Enable flash decoding by default for mi300 (when MLIR is available) - if(ctx != nullptr and starts_with(ctx->get_current_device().get_gfx_name(), "gfx94")) - return true; - return false; #else // Without MLIR, only enable if explicitly requested via env var diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp index 51ec4275f79..06db0614387 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp @@ -35,7 +35,7 @@ namespace gpu { MIGRAPHX_GPU_EXPORT bool mlir_enabled(); MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled(context* ctx); -MIGRAPHX_GPU_EXPORT bool mlir_flash_decoding_enabled(context* ctx); +MIGRAPHX_GPU_EXPORT bool mlir_flash_decoding_enabled(); MIGRAPHX_GPU_EXPORT bool mlir_geg_multi_user_intermediates_supported(); struct MIGRAPHX_GPU_EXPORT fuse_mlir diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 8eb08d83a4c..3a6dd467f84 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -226,7 +226,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti simplify_reshapes{.enable_op_shape_transform_op=true}, dead_code_elimination{}, enable_pass(mlir_enabled(), fuse_attention{.attn_enabled = mlir_attention_enabled(&ctx), - .flash_decoding_enabled = mlir_flash_decoding_enabled(&ctx)}), + .flash_decoding_enabled = mlir_flash_decoding_enabled()}), dead_code_elimination{}, optimize_module{}, enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}), diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 562d63e7842..5271d2c0786 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -1382,15 +1382,10 @@ TEST_CASE(flash_decoding_4d_auto_split_custom_params) .flash_decoding_min_chunk_size = 64}); // Larger chunk size // Check for flash decoding - bool found_flash_decoding = false; - for(const auto& ins : *p1.get_main_module()) - { - if(ins.name().find("group") != std::string::npos) - { - found_flash_decoding = true; - break; - } - } + auto* mod = p1.get_main_module(); + bool found_flash_decoding = std::any_of(mod->begin(), mod->end(), [](const auto& ins) { + return ins.name().find("group") != std::string::npos; + }); EXPECT(found_flash_decoding); } From 19c7edee186ae1a9d63bfaf15125a81bed550046 Mon Sep 17 00:00:00 2001 From: Breanna Devore-McDonald Date: Wed, 31 Dec 2025 19:12:05 -0600 Subject: [PATCH 30/30] happy new yeargit add -u! --- src/fuse_attention.cpp | 2 +- src/include/migraphx/fuse_attention.hpp | 2 +- src/include/migraphx/generic_float.hpp | 2 +- src/include/migraphx/split_factor.hpp | 2 +- src/rewrite_topk.cpp | 2 +- src/targets/gpu/fuse_mlir.cpp | 2 +- src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp | 2 +- src/targets/gpu/jit/reduce.cpp | 2 +- src/targets/gpu/target.cpp | 2 +- test/fuse_attention.cpp | 2 +- test/math_utils_test.cpp | 2 +- test/verify/test_attention_flash_decoding_3d.cpp | 2 +- test/verify/test_attention_flash_decoding_4d.cpp | 2 +- 13 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 1455c8e58a5..8912b186cdb 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/fuse_attention.hpp b/src/include/migraphx/fuse_attention.hpp index 4ad67d644f4..1b9329b809c 100644 --- a/src/include/migraphx/fuse_attention.hpp +++ b/src/include/migraphx/fuse_attention.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/generic_float.hpp b/src/include/migraphx/generic_float.hpp index c70621956d4..46bb7421793 100644 --- a/src/include/migraphx/generic_float.hpp +++ b/src/include/migraphx/generic_float.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp index e8c5236d91d..cb67695230d 100644 --- a/src/include/migraphx/split_factor.hpp +++ b/src/include/migraphx/split_factor.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/rewrite_topk.cpp b/src/rewrite_topk.cpp index d3075a3bf6b..19411680db8 100644 --- a/src/rewrite_topk.cpp +++ b/src/rewrite_topk.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 937a5eddc8c..c8f0b0f154f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp index 06db0614387..715fe1057a3 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index fa417c66480..3e8682d5f37 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 3a6dd467f84..891d80631bc 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 5271d2c0786..a662ec89e6e 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/math_utils_test.cpp b/test/math_utils_test.cpp index 2e8d40310dc..e0a3c9099ab 100644 --- a/test/math_utils_test.cpp +++ b/test/math_utils_test.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/verify/test_attention_flash_decoding_3d.cpp b/test/verify/test_attention_flash_decoding_3d.cpp index 068b2be1688..9d71733758d 100644 --- a/test/verify/test_attention_flash_decoding_3d.cpp +++ b/test/verify/test_attention_flash_decoding_3d.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/test/verify/test_attention_flash_decoding_4d.cpp b/test/verify/test_attention_flash_decoding_4d.cpp index 4ab98eeab5b..725cf296ae0 100644 --- a/test/verify/test_attention_flash_decoding_4d.cpp +++ b/test/verify/test_attention_flash_decoding_4d.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal