diff --git a/docs/reference/MIGraphX-dev-env-vars.rst b/docs/reference/MIGraphX-dev-env-vars.rst index b8e84e4c59e..5524ee45cf5 100644 --- a/docs/reference/MIGraphX-dev-env-vars.rst +++ b/docs/reference/MIGraphX-dev-env-vars.rst @@ -168,13 +168,49 @@ Model performance tunable variables change the compilation behavior of a model. | Default: Split-k performance configurations are turned off. + * - | ``MIGRAPHX_FLASH_DECODING_ENABLED`` + - | When set, flash decoding optimization for attention fusion is enabled, which splits the key-value sequence dimension for improved performance on long sequences. + - | ``1``: Enables flash decoding optimization. + | ``0``: Disables flash decoding optimization. + + | Default: ``0`` (disabled). + * - | ``MIGRAPHX_FLASH_DECODING_NUM_SPLITS`` - | Turns on flash decoding for attention fusion and sets the number of splits along the key-value sequence dimension. + | Sets the number of splits along the key-value sequence dimension when flash decoding is enabled. + + - | ``0``: Automatically calculates the number of splits based on sequence length. + | ``N`` (where N > 1): Uses exactly N splits along the key-value sequence dimension; the sequence is padded when its length is not divisible by N. A value of ``1`` performs no splitting. + + | Default: ``0`` (automatic calculation). + + | Note: Setting this variable to a value greater than ``0`` implicitly enables flash decoding, as if ``MIGRAPHX_FLASH_DECODING_ENABLED=1`` were set. If it is not set and ``MIGRAPHX_FLASH_DECODING_ENABLED=1``, the number of splits is calculated automatically. + + * - | ``MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE`` + | Sets the minimum chunk size for automatic split calculation in flash decoding. - - | ``0``: Flash decoding is turned off (i.e., number of splits is 0). - | ``N`` (where N > 1): Enables flash decoding with N splits along the key-value sequence dimension. For example, ``2`` enables flash decoding with 2 splits, ``4`` with 4 splits, etc. + - | Takes a positive integer representing the minimum size of each chunk after splitting. + + | Default: ``32``. + + | Note: Only used when automatic split calculation is enabled. 
- | Default: flash decoding is turned off. + * - | ``MIGRAPHX_FLASH_DECODING_MAX_SPLITS`` + | Sets the maximum number of splits allowed during automatic split calculation. + + - | Takes a positive integer representing the maximum number of splits. + + | Default: ``16``. + + | Note: Only used when automatic split calculation is enabled. + + * - | ``MIGRAPHX_FLASH_DECODING_THRESHOLD`` + | Sets the minimum sequence length threshold for flash decoding to be applied. + + - | Takes a positive integer. Flash decoding is only applied when the sequence length is greater than or equal to this threshold. + + | Default: ``32``. + + | Note: Only used when automatic split calculation is enabled. * - | ``MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT`` | When set, FP16 is not converted to FP32 in the ``InstanceNormalization`` ONNX operator. diff --git a/src/fuse_attention.cpp b/src/fuse_attention.cpp index 7fc29634c4f..8912b186cdb 100644 --- a/src/fuse_attention.cpp +++ b/src/fuse_attention.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,9 @@ #include #include #include +#include #include +#include #include #include @@ -37,9 +39,81 @@ inline namespace MIGRAPHX_INLINE_NS { namespace { +// env vars for flash decoding configuration MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD); + +// Get config value with priority: struct member (if != default_value) > env var > default_value. +// A member explicitly set to default_value counts as unset, so the env var still applies. +template +std::size_t get_config_value(std::size_t struct_value, std::size_t default_value, EnvVar env_var) +{ + // if struct member is not the default value, use it + if(struct_value != default_value) + { + return struct_value; + } + + // otherwise return env var value, or default if not set + return value_of(env_var, default_value); +} + +// Get num_splits with priority: struct member > env var > 0 (not set) +std::size_t get_num_splits(std::size_t member_num_splits) +{ + // if struct member is set (non-zero), use it + if(member_num_splits > 0) + { + return member_num_splits; + } + + // otherwise return env var value, or 0 if not set + return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); +} -std::size_t get_num_splits() { return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); } +// calculate optimal flash decoding splits +inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length, + std::size_t min_chunk_size, + std::size_t max_splits) +{ + std::size_t r = sequence_length; + return split_dim(r, min_chunk_size, max_splits); +} + +// calculate the actual number of groups for flash decoding +// returns 0 if no splitting should be performed +inline std::size_t calculate_groups(std::size_t groups, + std::size_t 
sequence_length, + std::size_t threshold, + std::size_t min_chunk_size, + std::size_t max_splits) +{ + // if groups is explicitly set and valid, use it + if(groups > 1) + return groups; + + // if groups is 0, auto-calculate based on sequence length + if(groups == 0) + { + // skip if sequence is too short + if(sequence_length < threshold) + return 0; + + std::size_t actual_groups = + calculate_flash_decoding_splits(sequence_length, min_chunk_size, max_splits); + + // return 0 if auto-calculation determines no splitting needed + if(actual_groups <= 1) + return 0; + + return actual_groups; + } + + // groups == 1 or invalid, no splitting + return 0; +} // TODO: Write this in matcher.hpp as a general matcher for iterating through inputs inline auto pointwise_inputs() @@ -216,8 +290,11 @@ struct find_attention struct find_flash_decoding { - // number of groups. User-provided for now - std::size_t groups; + // configuration from fuse_attention pass config + std::size_t configured_splits; + std::size_t configured_threshold; + std::size_t configured_max_splits; + std::size_t configured_min_chunk_size; auto matcher() const { @@ -267,7 +344,8 @@ struct find_flash_decoding std::vector v_shape; // final V shape: [B, G, N/G, D] }; - transformed_shapes_result get_transformed_shapes(const std::vector& input_shapes) const + transformed_shapes_result get_transformed_shapes(const std::vector& input_shapes, + std::size_t num_groups) const { assert(input_shapes.size() == 3 and "Expected Q, K, V shapes"); @@ -279,11 +357,11 @@ struct find_flash_decoding // 4D: Q_lens = [B, H, M, k] size_t ndim = q_lens.size(); size_t n = k_lens[ndim - 1]; - size_t g = groups; + size_t g = num_groups; - // TODO: handle uneven splits; this is caught in `apply` for now - assert(n % g == 0 and - "Key-value sequence length must be divisible by number of splits/groups"); + // Note: sequence length may have been padded to be divisible by num_groups + assert(n % g == 0 and "Key-value sequence length must be 
divisible by number of " + "splits/groups (after padding)"); size_t n_split = n / g; transformed_shapes_result result; @@ -449,15 +527,31 @@ struct find_flash_decoding assert(k_param->name() == "@param" and "K should be a parameter"); assert(v_param->name() == "@param" and "V should be a parameter"); - // check if N dimension is evenly divisible by num_splits - if(k_param->get_shape().lens().back() % groups != 0) + // Get sequence length from K shape + auto k_shape = k_param->get_shape(); + std::size_t sequence_length = k_shape.lens().back(); + + // read configuration with priority: struct member (if not default) > env var > default + std::size_t groups = get_num_splits(configured_splits); + std::size_t threshold = + get_config_value(configured_threshold, 32, MIGRAPHX_FLASH_DECODING_THRESHOLD{}); + std::size_t min_chunk_size = get_config_value( + configured_min_chunk_size, 32, MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{}); + std::size_t max_splits = + get_config_value(configured_max_splits, 16, MIGRAPHX_FLASH_DECODING_MAX_SPLITS{}); + + std::size_t actual_groups = + calculate_groups(groups, sequence_length, threshold, min_chunk_size, max_splits); + if(actual_groups == 0) return; - // get Q, V, K shapes from gemms - auto qkv_shapes = get_qkv_shapes(q_param, k_param, v_param); - - // check shapes are ok and get flash decoding transformed shapes (Q', V', K') - auto transform_info = get_transformed_shapes(qkv_shapes); + // calculate padding if sequence length not evenly divisible + std::size_t padding_needed = 0; + if(sequence_length % actual_groups != 0) + { + // round up to nearest multiple of actual_groups + padding_needed = ceil_mul_of(sequence_length, actual_groups) - sequence_length; + } // create mapping from submodule params to main module inputs auto group_inputs = attn_group_ins->inputs(); @@ -468,6 +562,39 @@ struct find_flash_decoding auto k = map_param_to_main.at(k_param); auto v = map_param_to_main.at(v_param); + // save original references before padding 
(needed for group_inputs replacement later) + auto q_orig = q; + auto k_orig = k; + auto v_orig = v; + + // pad Q, K and V if necessary + if(padding_needed > 0) + { + // Q shape: [B, M, k] or [B, H, M, k] for 4D. Padding on M (sequence length dim) + auto q_ndim = q->get_shape().ndim(); + std::vector q_pads(2 * q_ndim, 0); + q_pads[q_ndim + q_ndim - 2] = padding_needed; // pad right on M dim (second to last) + q = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", q_pads}}), q); + + // K shape: [B, k, N] or [B, H, k, N] for 4D. Padding on N + auto k_ndim = k->get_shape().ndim(); + std::vector k_pads(2 * k_ndim, 0); + k_pads[k_ndim + k_ndim - 1] = padding_needed; // pad right on last dim + k = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", k_pads}}), k); + + // V shape: [B, N, D] or [B, H, N, D] for 4D + auto v_ndim = v->get_shape().ndim(); + std::vector v_pads(2 * v_ndim, 0); + v_pads[v_ndim + v_ndim - 2] = padding_needed; // pad right on N dim + v = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", v_pads}}), v); + } + + // get Q, K, V shapes (using potentially padded K and V) + auto qkv_shapes = get_qkv_shapes(q, k, v); + + // check shapes are ok and get flash decoding transformed shapes (Q', V', K') + auto transform_info = get_transformed_shapes(qkv_shapes, actual_groups); + // insert reshape operations before group, for Q, K, V auto q_ndim = q->get_shape().lens().size(); int64_t g_axis = q_ndim - 2; @@ -493,18 +620,19 @@ struct find_flash_decoding attn_group_ins, make_op("reshape", {{"dims", transform_info.v_shape}}), v); // create new input list by replacing Q, K, V with reshaped versions + // use original references (before padding) for comparison std::vector new_group_inputs = group_inputs; for(size_t i = 0; i < group_inputs.size(); ++i) { - if(group_inputs[i] == q) + if(group_inputs[i] == q_orig) { new_group_inputs[i] = q_reshaped; } - else if(group_inputs[i] == k) + else if(group_inputs[i] == k_orig) { 
new_group_inputs[i] = k_reshaped; } - else if(group_inputs[i] == v) + else if(group_inputs[i] == v_orig) { new_group_inputs[i] = v_reshaped; } @@ -604,8 +732,26 @@ struct find_flash_decoding auto final_squeezed_o = mm.insert_instruction( attn_group_ins, make_op("squeeze", {{"axes", {g_axis}}}), final_output_o); + // if padding was applied, slice to remove it + instruction_ref final_result = final_squeezed_o; + if(padding_needed > 0) + { + // need to slice the sequence dimension to remove padding + // final_squeezed_o has shape like [B, M_padded, D], need to slice M back to original + auto output_shape = final_squeezed_o->get_shape(); + const auto& output_lens = output_shape.lens(); + std::size_t seq_dim_idx = output_lens.size() - 2; // sequence dim is second to last + std::size_t original_seq_len = output_lens[seq_dim_idx] - padding_needed; + + final_result = mm.insert_instruction( + attn_group_ins, + make_op("slice", + {{"axes", {seq_dim_idx}}, {"starts", {0}}, {"ends", {original_seq_len}}}), + final_squeezed_o); + } + // replace the original group instruction with the final result - mm.replace_instruction(attn_group_ins, final_squeezed_o); + mm.replace_instruction(attn_group_ins, final_result); } }; @@ -797,20 +943,16 @@ void fuse_attention::apply(module_pass_manager& mpm) const mpm.run_pass(dead_code_elimination{}); } - std::size_t num_splits = 0; - if(flash_decoding_num_splits.has_value()) - { - // Use the value from the constructor (for testing) - num_splits = *flash_decoding_num_splits; - } - else - { - // Default behavior: read from the env var (for non-test use) - num_splits = get_num_splits(); - } - if(num_splits > 1) + // enable flash decoding if splits configured or flash decoding is enabled + std::size_t configured_splits = get_num_splits(flash_decoding_num_splits); + if(configured_splits > 0 or flash_decoding_enabled) { - match::find_matches(mpm, find_flash_decoding{.groups = num_splits}); + match::find_matches( + mpm, + 
find_flash_decoding{.configured_splits = flash_decoding_num_splits, + .configured_threshold = flash_decoding_threshold, + .configured_max_splits = flash_decoding_max_splits, + .configured_min_chunk_size = flash_decoding_min_chunk_size}); mpm.run_pass(dead_code_elimination{}); } } diff --git a/src/include/migraphx/fuse_attention.hpp b/src/include/migraphx/fuse_attention.hpp index db07c9aac02..1b9329b809c 100644 --- a/src/include/migraphx/fuse_attention.hpp +++ b/src/include/migraphx/fuse_attention.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -38,7 +38,11 @@ struct module_pass_manager; struct MIGRAPHX_EXPORT fuse_attention { bool attn_enabled = false; - std::optional flash_decoding_num_splits = std::nullopt; + bool flash_decoding_enabled = false; + std::size_t flash_decoding_num_splits = 0; + std::size_t flash_decoding_threshold = 32; + std::size_t flash_decoding_max_splits = 16; + std::size_t flash_decoding_min_chunk_size = 32; std::string name() const { return "fuse_attention"; } void apply(module_pass_manager& mpm) const; diff --git a/src/include/migraphx/generic_float.hpp b/src/include/migraphx/generic_float.hpp index b077d16631b..46bb7421793 100644 --- a/src/include/migraphx/generic_float.hpp +++ b/src/include/migraphx/generic_float.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -42,6 +42,13 @@ constexpr std::size_t integer_divide_ceil(std::size_t x, std::size_t y) return (x + y - std::size_t{1}) / y; } +// compute the smallest multiple of y that is greater than or equal to x +// this is equivalent to y * ceil(x / y) +constexpr std::size_t ceil_mul_of(std::size_t x, std::size_t y) +{ + return y * integer_divide_ceil(x, y); +} + template struct unsigned_type { diff --git a/src/include/migraphx/split_factor.hpp b/src/include/migraphx/split_factor.hpp new file mode 100644 index 00000000000..cb67695230d --- /dev/null +++ b/src/include/migraphx/split_factor.hpp @@ -0,0 +1,75 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#ifndef MIGRAPHX_GUARD_SPLIT_FACTOR_HPP +#define MIGRAPHX_GUARD_SPLIT_FACTOR_HPP + +#include +#include +#include +#include +#include +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +/** + * Calculate split factor with or without maximum splits constraint. + * + * Used by: + * - rewrite_topk: Splits large topk operations for better performance + * - flash_decoding: Splits attention sequence dimension for parallelization + * - split_reduce (GPU JIT): Splits reduction operations + * + * The split factor is built by repeatedly dividing r by the smallest of 2, 3, 5, 7, 11 that + * divides it evenly, stopping when r <= min_size, the factor >= max_splits, or none divides r. + * + * @param r The value to split. This is passed by reference and will be modified to the remaining + * value after splitting. + * @param min_size Target threshold - splits until remaining size is no greater than this value + * @param max_splits Target threshold - splitting stops once the split factor is greater than + * or equal to max_splits, so the result may exceed max_splits. Optional + * @return The accumulated split factor (1 if r was not split) + */ + +inline std::size_t split_dim(std::size_t& r, + std::size_t min_size, + std::size_t max_splits = std::numeric_limits::max()) +{ + std::size_t n = 1; + auto factors = make_array(2, 3, 5, 7, 11); + while(r > min_size and n < max_splits) + { + // NOLINTNEXTLINE(readability-qualified-auto) + auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); + if(it == factors.end()) + break; + r /= *it; + n *= *it; + } + return n; +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif // MIGRAPHX_GUARD_SPLIT_FACTOR_HPP diff --git a/src/rewrite_topk.cpp b/src/rewrite_topk.cpp index 55fd0e2bb79..19411680db8 100644 --- a/src/rewrite_topk.cpp +++ b/src/rewrite_topk.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. 
All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,6 +27,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -38,23 +39,6 @@ struct find_large_topk std::size_t n_threshold = 0; auto matcher() const { return match::name("topk"); } - static std::size_t split_dim(std::size_t& r, std::size_t min_size) - { - std::size_t n = 1; - auto factors = make_array(2, 3, 5, 7, 11); - while(r > min_size) - { - // NOLINTNEXTLINE(readability-qualified-auto) - auto it = - std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); - if(it == factors.end()) - break; - r /= *it; - n *= *it; - } - return n; - } - void apply(module& m, const match::matcher_result& r) const { auto ins = r.result; diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index cf3cf6f3416..c8f0b0f154f 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -48,6 +48,7 @@ namespace gpu { MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_EXTRA_MLIR); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_INPUT_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_REDUCE_FUSION); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_ENABLED); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR_GEG_FUSION); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_MLIR); /** @@ -143,6 +144,23 @@ bool mlir_attention_enabled(context* ctx) #endif } +bool mlir_flash_decoding_enabled() +{ +#ifdef MIGRAPHX_MLIR + if(not mlir_enabled()) + return false; + + // Check if explicitly enabled via environment variable + if(enabled(MIGRAPHX_FLASH_DECODING_ENABLED{})) + return true; + + return false; +#else + // Without MLIR, only enable if explicitly requested via env var + return enabled(MIGRAPHX_FLASH_DECODING_ENABLED{}); +#endif +} + #ifdef MIGRAPHX_MLIR struct mlir_op diff --git a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp index 82edc35e581..715fe1057a3 100644 --- a/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp +++ b/src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -35,6 +35,7 @@ namespace gpu { MIGRAPHX_GPU_EXPORT bool mlir_enabled(); MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled(context* ctx); +MIGRAPHX_GPU_EXPORT bool mlir_flash_decoding_enabled(); MIGRAPHX_GPU_EXPORT bool mlir_geg_multi_user_intermediates_supported(); struct MIGRAPHX_GPU_EXPORT fuse_mlir diff --git a/src/targets/gpu/jit/reduce.cpp b/src/targets/gpu/jit/reduce.cpp index c8fcf89b04c..3e8682d5f37 100644 --- a/src/targets/gpu/jit/reduce.cpp +++ b/src/targets/gpu/jit/reduce.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include namespace migraphx { @@ -175,18 +175,8 @@ static std::vector split_reduce(const std::vector& inputs, assert(faxis < reduce_shape.lens().size()); - std::size_t n = 1; - auto r = input_shape.lens()[faxis]; - auto factors = make_array(2, 3, 5, 7, 11); - while(r > min_size and n < max_splits) - { - // NOLINTNEXTLINE(readability-qualified-auto) - auto it = std::find_if(factors.begin(), factors.end(), [&](auto d) { return r % d == 0; }); - if(it == factors.end()) - break; - r /= *it; - n *= *it; - } + std::size_t r = input_shape.lens()[faxis]; + std::size_t n = split_dim(r, min_size, max_splits); assert(n != 1); std::transform( inputs.begin(), inputs.end(), std::back_inserter(result), [&](const shape& s) -> shape { diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index de2da57ebc4..891d80631bc 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -1,7 +1,7 @@ 
/* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -225,7 +225,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, simplify_reshapes{.enable_op_shape_transform_op=true}, dead_code_elimination{}, - enable_pass(mlir_enabled(), fuse_attention{mlir_attention_enabled(&ctx)}), + enable_pass(mlir_enabled(), fuse_attention{.attn_enabled = mlir_attention_enabled(&ctx), + .flash_decoding_enabled = mlir_flash_decoding_enabled()}), dead_code_elimination{}, optimize_module{}, enable_pass(disabled(MIGRAPHX_ENABLE_FULL_DYNAMIC{}), fuse_pointwise_reduce{}), diff --git a/test/fuse_attention.cpp b/test/fuse_attention.cpp index 41b145fc267..a662ec89e6e 100644 --- a/test/fuse_attention.cpp +++ b/test/fuse_attention.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -31,20 +31,93 @@ #include #include #include +#include +#include +#include #include #include #include #include #include #include - -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS); +#include static void run_pass(migraphx::program& p, migraphx::fuse_attention fa = {}) { migraphx::run_passes(p, {fa, migraphx::dead_code_elimination{}}); } +// Test helper functions used in fuse_attention pass +TEST_CASE(get_num_splits_from_member) +{ + // Test that member variable takes precedence over environment variable + migraphx::fuse_attention fa; + fa.flash_decoding_num_splits = 8; + + // Test that struct members are set correctly + EXPECT(fa.flash_decoding_num_splits == 8); + EXPECT(fa.flash_decoding_threshold == 32); // default value + EXPECT(fa.flash_decoding_max_splits == 16); // default value + EXPECT(fa.flash_decoding_min_chunk_size == 32); // default value +} + +TEST_CASE(calculate_flash_decoding_splits_basic) +{ + // sequence_length that can be split evenly + // 256 with min_chunk=32 should split to 8 (256/8 = 32) + std::size_t seq_len1 = 256; + std::size_t result1 = migraphx::split_dim(seq_len1, 32, 16); + EXPECT(result1 == 8); + + // sequence_length with max_splits constraint + // 1024 with min_chunk=64 and max_splits=8 should be limited to 8 + std::size_t seq_len2 = 1024; + std::size_t result2 = migraphx::split_dim(seq_len2, 64, 8); + EXPECT(result2 == 8); + + // small sequence that shouldn't be split + // 32 with min_chunk=32 should return 1 (no split) + std::size_t seq_len3 = 32; + std::size_t result3 = migraphx::split_dim(seq_len3, 32, 16); + EXPECT(result3 == 1); + + // prime number sequence length + // 97 with min_chunk=10 should return 1 (can't split prime) + std::size_t seq_len4 = 97; + std::size_t result4 = migraphx::split_dim(seq_len4, 10, 16); + EXPECT(result4 == 1); + + // typical 
attention sequence lengths + // 2048 with min_chunk=128 and max_splits=16 + std::size_t seq_len5 = 2048; + std::size_t result5 = migraphx::split_dim(seq_len5, 128, 16); + EXPECT(result5 == 16); +} + +TEST_CASE(padding_calculation) +{ + // not evenly divisible - padding needed + std::size_t seq_len2 = 100; + std::size_t groups2 = 8; + // 100 % 8 != 0, so padding is needed + std::size_t padding2 = migraphx::ceil_mul_of(seq_len2, groups2) - seq_len2; + EXPECT(padding2 == 4); // 104 - 100 = 4 + + // sequence length = 127, groups = 16 + std::size_t seq_len3 = 127; + std::size_t groups3 = 16; + // 127 % 16 != 0, so padding is needed + std::size_t padding3 = migraphx::ceil_mul_of(seq_len3, groups3) - seq_len3; + EXPECT(padding3 == 1); // 128 - 127 = 1 + + // large sequence with padding + std::size_t seq_len4 = 2049; + std::size_t groups4 = 32; + // 2049 % 32 != 0, so padding is needed + std::size_t padding4 = migraphx::ceil_mul_of(seq_len4, groups4) - seq_len4; + EXPECT(padding4 == 31); // 2080 - 2049 = 31 +} + TEST_CASE(gemm_softmax_gemm) { migraphx::shape s1{migraphx::shape::half_type, {1, 12, 256, 256}}; @@ -445,7 +518,8 @@ TEST_CASE(gemm_softmax_gemm_flash_decoding) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 2}); + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { auto* mm = p2.get_main_module(); @@ -546,7 +620,8 @@ TEST_CASE(flash_decoding_3d) auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 2}); + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 2}); migraphx::program p2; { @@ -635,67 +710,266 @@ TEST_CASE(flash_decoding_3d) EXPECT(p1.sort() == p2.sort()); } -TEST_CASE(flash_decoding_disabled) 
+TEST_CASE(flash_decoding_3d_rectangular) { - migraphx::shape s1{migraphx::shape::half_type, {1, 12, 256, 256}}; + // 3D Shape: [batch, head_dim, sequence_length] + migraphx::shape s_3d{migraphx::shape::half_type, {1, 256, 240}}; + migraphx::shape st_3d{migraphx::shape::half_type, {1, 240, 256}}; + const std::size_t num_splits = 2; migraphx::program p1; { auto* mm = p1.get_main_module(); - auto a = mm->add_parameter("1", s1); - auto b = mm->add_parameter("2", s1); - auto b1 = mm->add_parameter("3", s1); - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), - b1); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); - rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), - rmax); - auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); - auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); - rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), - rsum); - auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); - auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + auto a = mm->add_parameter("q", s_3d); // [1, 256, 240] + auto b = mm->add_parameter("k", s_3d); // [1, 256, 240] + auto b1 = mm->add_parameter("v", st_3d); // [1, 240, 256] + a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), + a); // [1, 240, 256] + auto gemm1 = mm->add_instruction( + migraphx::make_op("dot"), a, b); // [1, 240, 256] x [1, 256, 240] = [1, 240, 240] + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), + gemm1); // [1, 240, 1] + rmax = mm->add_instruction( + 
migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rmax); // [1, 240, 240] + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // [1, 240, 240] + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 240, 240] + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), + exp); // [1, 240, 1] + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rsum); // [1, 240, 240] + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // [1, 240, 240] + auto gemm2 = mm->add_instruction( + migraphx::make_op("dot"), div, b1); // [1, 240, 240] x [1, 240, 256] = [1, 240, 256] mm->add_return({gemm2}); } - run_pass(p1, {.attn_enabled = true, .flash_decoding_num_splits = 0}); + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = num_splits}); - // Expected result: only attention fusion, no flash decoding migraphx::program p2; { - auto* mm = p2.get_main_module(); - auto a = mm->add_parameter("1", s1); - auto b = mm->add_parameter("2", s1); - auto b1 = mm->add_parameter("3", s1); - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), - b1); + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", st_3d); + size_t g_axis = 1; + + // New shapes for flash decoding - 240 split into 2x120 + std::vector q_prime_shape = {1, num_splits, 240, 256}; + std::vector k_prime_shape = {1, num_splits, 256, 120}; + std::vector v_prime_shape = {1, num_splits, 120, 256}; + + auto a_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); + auto a_unsqueeze = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", 
{g_axis}}}), a_transpose); + auto a_broadcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); + + // K: [1, 256, 240] -> [1, 256, 2, 120] -> [1, 2, 256, 120] + auto b_reshape_intermediate = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 256, 2, 120}}}), b); + auto b_reshape = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + b_reshape_intermediate); + + // V: [1, 240, 256] -> [1, 2, 120, 256] -> [1, 2, 120, 256] + auto b1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1); + auto group = add_group( p2, - "attn0", + "attn0_flash_decoding", "attention", - {a, b, b1}, + {a_broadcast, b_reshape, b1_reshape}, {"x0", "x1", "x2"}, - [=](auto* gm, const auto& inputs) { + [&](auto* gm, const auto& inputs) { auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); auto rmax = gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); - rmax = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rmax); - auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 120}}}), + rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); auto rsum = gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); - rsum = gm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), rsum); - auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 120}}}), + rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); auto gemm2 = 
gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); - return std::vector{gemm2}; + auto log = gm->add_instruction(migraphx::make_op("log"), rsum); + auto add = gm->add_instruction(migraphx::make_op("add"), rmax, log); + return std::vector{gemm2, add}; }); - mm->add_return({group}); + auto o_p = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), group); + auto lse = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), group); + + // Kernel 2 + auto k2_rmax = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {g_axis}}}), lse); + auto k2_broad1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 1}}}), k2_rmax); + auto k2_sub = mm->add_instruction(migraphx::make_op("sub"), lse, k2_broad1); + auto k2_exp = mm->add_instruction(migraphx::make_op("exp"), k2_sub); + auto k2_rsum1 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_exp); + auto k2_broad2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 240, 1}}}), k2_rsum1); + auto k2_div = mm->add_instruction(migraphx::make_op("div"), k2_exp, k2_broad2); + auto k2_broad3 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), k2_div); + auto k2_convert = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), k2_broad3); + auto k2_mul = mm->add_instruction(migraphx::make_op("mul"), o_p, k2_convert); + auto k2_rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_mul); + auto k2_squeeze = + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {g_axis}}}), k2_rsum2); + mm->add_return({k2_squeeze}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(flash_decoding_3d_padding) +{ + // 3D Shape: [batch, head_dim, sequence_length] + migraphx::shape s_3d{migraphx::shape::half_type, {1, 256, 241}}; + migraphx::shape 
st_3d{migraphx::shape::half_type, {1, 241, 256}}; + const std::size_t num_splits = 2; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); // [1, 256, 241] + auto b = mm->add_parameter("k", s_3d); // [1, 256, 241] + auto b1 = mm->add_parameter("v", st_3d); // [1, 241, 256] + a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), + a); // [1, 241, 256] + auto gemm1 = mm->add_instruction( + migraphx::make_op("dot"), a, b); // [1, 241, 256] x [1, 256, 241] = [1, 241, 241] + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), + gemm1); // [1, 241, 1] + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rmax); // [1, 241, 241] + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // [1, 241, 241] + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // [1, 241, 241] + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), + exp); // [1, 241, 1] + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rsum); // [1, 241, 241] + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // [1, 241, 241] + auto gemm2 = mm->add_instruction( + migraphx::make_op("dot"), div, b1); // [1, 241, 241] x [1, 241, 256] = [1, 241, 256] + mm->add_return({gemm2}); + } + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = num_splits}); + + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", st_3d); + size_t g_axis = 1; + + // Padding 241 -> 242, split into 2x121 + std::vector q_prime_shape = {1, num_splits, 242, 256}; + std::vector k_prime_shape = {1, num_splits, 256, 121}; + std::vector v_prime_shape = {1, num_splits, 
121, 256}; + + // Q: [1, 256, 241] -> transpose -> [1, 241, 256] -> pad -> [1, 242, 256] + auto a_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), a); + auto a_padded = mm->add_instruction( + migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 1, 0}}}), a_transpose); + auto a_unsqueeze = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {g_axis}}}), a_padded); + auto a_broadcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); + + // K: [1, 256, 241] -> pad -> [1, 256, 242] -> reshape -> [1, 256, 2, 121] -> transpose -> + // [1, 2, 256, 121] + auto b_padded = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 0, 1}}}), b); + auto b_reshape_intermediate = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {1, 256, 2, 121}}}), b_padded); + auto b_reshape = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + b_reshape_intermediate); + + // V: [1, 241, 256] -> pad -> [1, 242, 256] -> reshape -> [1, 2, 121, 256] + auto b1_padded = + mm->add_instruction(migraphx::make_op("pad", {{"pads", {0, 0, 0, 0, 1, 0}}}), b1); + auto b1_reshape = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1_padded); + + auto group = add_group( + p2, + "attn0_flash_decoding", + "attention", + {a_broadcast, b_reshape, b1_reshape}, + {"x0", "x1", "x2"}, + [&](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 121}}}), + rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + 
gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 121}}}), + rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + auto log = gm->add_instruction(migraphx::make_op("log"), rsum); + auto add = gm->add_instruction(migraphx::make_op("add"), rmax, log); + return std::vector{gemm2, add}; + }); + auto o_p = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), group); + auto lse = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), group); + + // Kernel 2 + auto k2_rmax = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {g_axis}}}), lse); + auto k2_broad1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 1}}}), k2_rmax); + auto k2_sub = mm->add_instruction(migraphx::make_op("sub"), lse, k2_broad1); + auto k2_exp = mm->add_instruction(migraphx::make_op("exp"), k2_sub); + auto k2_rsum1 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_exp); + auto k2_broad2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, num_splits, 242, 1}}}), k2_rsum1); + auto k2_div = mm->add_instruction(migraphx::make_op("div"), k2_exp, k2_broad2); + auto k2_broad3 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), k2_div); + auto k2_convert = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), k2_broad3); + auto k2_mul = mm->add_instruction(migraphx::make_op("mul"), o_p, k2_convert); + auto k2_rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_mul); + auto k2_squeeze = + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {g_axis}}}), k2_rsum2); + + // Slice to 
remove padding: [1, 242, 256] -> [1, 241, 256] + auto sliced = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {241}}}), + k2_squeeze); + + mm->add_return({sliced}); } EXPECT(p1.sort() == p2.sort()); } @@ -878,6 +1152,428 @@ TEST_CASE(kv_cache_attention) EXPECT(p1.sort() == p2.sort()); } +// Test automatic splitting with num_splits = 0 (auto-calculate) +TEST_CASE(flash_decoding_3d_auto_split_large_sequence) +{ + // 3D Shape: [batch, sequence_length, head_dim] - Use larger sequence to trigger auto-splitting + migraphx::shape s_3d{migraphx::shape::half_type, {1, 512, 512}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Use auto-splitting: num_splits = 0, with sequence length 512 > threshold 32 + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); + + // 
Expected program with automatic splitting (should calculate 16 splits for 512 sequence) + const std::size_t expected_splits = 16; // 512 = 2^9, split until chunk = 32, so 512/16 = 32 + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + size_t g_axis = 1; + + // New shapes for flash decoding with calculated splits + std::vector q_prime_shape = {1, expected_splits, 512, 512}; + std::vector k_prime_shape = {1, expected_splits, 512, 32}; // 512/16 = 32 + std::vector v_prime_shape = {1, expected_splits, 32, 512}; + + auto a_unsqueeze = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {g_axis}}}), a); + auto a_broadcast = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), a_unsqueeze); + + auto b_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + // K: [1, 512, 512] -> [1, 512, 16, 32] -> [1, 16, 512, 32] + auto b_reshape_intermediate = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", {1, 512, expected_splits, 32}}}), b_transpose); + auto b_reshape = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), + b_reshape_intermediate); + + auto b1_transpose = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto b1_reshape = mm->add_instruction( + migraphx::make_op("reshape", {{"dims", v_prime_shape}}), b1_transpose); + + auto group = add_group( + p2, + "attn0_flash_decoding", + "attention", + {a_broadcast, b_reshape, b1_reshape}, + {"x0", "x1", "x2"}, + [&](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", + 
{{"out_lens", {1, expected_splits, 512, 32}}}), + rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", + {{"out_lens", {1, expected_splits, 512, 32}}}), + rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + auto log = gm->add_instruction(migraphx::make_op("log"), rsum); + auto add = gm->add_instruction(migraphx::make_op("add"), rmax, log); + return std::vector{gemm2, add}; + }); + auto o_p = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 0}}), group); + auto lse = mm->add_instruction(migraphx::make_op("get_tuple_elem", {{"index", 1}}), group); + + // Kernel 2 + auto k2_rmax = + mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {g_axis}}}), lse); + auto k2_broad1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), + k2_rmax); + auto k2_sub = mm->add_instruction(migraphx::make_op("sub"), lse, k2_broad1); + auto k2_exp = mm->add_instruction(migraphx::make_op("exp"), k2_sub); + auto k2_rsum1 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_exp); + auto k2_broad2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, expected_splits, 512, 1}}}), + k2_rsum1); + auto k2_div = mm->add_instruction(migraphx::make_op("div"), k2_exp, k2_broad2); + auto k2_broad3 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", q_prime_shape}}), k2_div); + auto k2_convert = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), k2_broad3); + auto k2_mul = mm->add_instruction(migraphx::make_op("mul"), o_p, 
k2_convert); + auto k2_rsum2 = + mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {g_axis}}}), k2_mul); + auto k2_squeeze = + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {g_axis}}}), k2_rsum2); + mm->add_return({k2_squeeze}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(flash_decoding_3d_auto_split_small_sequence) +{ + // 3D Shape: [batch, sequence_length, head_dim] - Small sequence that should NOT trigger + // splitting + migraphx::shape s_3d{migraphx::shape::half_type, {1, 16, 16}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Use auto-splitting: num_splits = 0, with small sequence length 16 < threshold 32 + run_pass( + p1, {.attn_enabled = true, .flash_decoding_enabled = true, .flash_decoding_num_splits = 0}); + + // Expected program with regular attention (no flash decoding for small sequence) + migraphx::program 
p2; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto group = add_group( + p2, + "attn0", + "attention", + {a, b, b1}, + {"x0", "x1", "x2"}, + [=](auto* gm, const auto& inputs) { + auto gemm1 = gm->add_instruction(migraphx::make_op("dot"), inputs[0], inputs[1]); + auto rmax = + gm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + auto rmax_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rmax); + auto sub = gm->add_instruction(migraphx::make_op("sub"), gemm1, rmax_broad); + auto exp = gm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = + gm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + auto rsum_broad = gm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), + rsum); + auto div = gm->add_instruction(migraphx::make_op("div"), exp, rsum_broad); + auto gemm2 = gm->add_instruction(migraphx::make_op("dot"), div, inputs[2]); + return std::vector{gemm2}; + }); + mm->add_return({group}); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(flash_decoding_4d_auto_split_custom_params) +{ + // 4D Shape: [batch, heads, sequence_length, head_dim] - Test with custom parameters + migraphx::shape s1{migraphx::shape::half_type, {1, 12, 256, 256}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("1", s1); + auto b = mm->add_parameter("2", s1); + auto b1 = mm->add_parameter("3", s1); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 
2}}}), + b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {3}}}), gemm1); + rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {3}}}), exp); + rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s1.lens()}}), + rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Test with custom min_chunk_size and max_splits + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, // Auto-calculate + .flash_decoding_threshold = 32, + .flash_decoding_max_splits = 4, // Smaller max splits + .flash_decoding_min_chunk_size = 64}); // Larger chunk size + + // Check for flash decoding + auto* mod = p1.get_main_module(); + bool found_flash_decoding = std::any_of(mod->begin(), mod->end(), [](const auto& ins) { + return ins.name().find("group") != std::string::npos; + }); + + EXPECT(found_flash_decoding); +} + +TEST_CASE(flash_decoding_auto_split_threshold_behavior) +{ + // Test threshold behavior - sequence right at the threshold boundary + migraphx::shape s_3d{migraphx::shape::half_type, {1, 127, 127}}; + + migraphx::program p1; + migraphx::program p2; + + // Test 1: sequence length below threshold - should NOT split + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", 
{{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, + .flash_decoding_threshold = 128, // Greater than sequence length (127) + .flash_decoding_max_splits = 8, + .flash_decoding_min_chunk_size = 32}); + + // Test 2: sequence length at threshold - should split + migraphx::shape s_3d_larger{migraphx::shape::half_type, {1, 128, 128}}; + { + auto* mm = p2.get_main_module(); + auto a = mm->add_parameter("q", s_3d_larger); + auto b = mm->add_parameter("k", s_3d_larger); + auto b1 = mm->add_parameter("v", s_3d_larger); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = 
mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + run_pass(p2, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, + .flash_decoding_threshold = 128, // Equal to sequence length (128) + .flash_decoding_max_splits = 8, + .flash_decoding_min_chunk_size = 32}); + + // Check results - look for flash decoding by checking module names + bool found_flash_decoding_p1 = false; + bool found_flash_decoding_p2 = false; + bool found_regular_attention_p1 = false; + + for(const auto& ins : *p1.get_main_module()) + { + if(ins.name().find("group") != std::string::npos) + { + // Check the module name to distinguish flash decoding from regular attention + const auto& module_inputs = ins.module_inputs(); + if(not module_inputs.empty()) + { + const auto& mod_name = module_inputs[0]->name(); + if(mod_name.find("flash_decoding") != std::string::npos) + { + found_flash_decoding_p1 = true; + } + else + { + found_regular_attention_p1 = true; + } + } + } + } + + for(const auto& ins : *p2.get_main_module()) + { + if(ins.name().find("group") != std::string::npos) + { + const auto& module_inputs = ins.module_inputs(); + if(not module_inputs.empty()) + { + const auto& mod_name = module_inputs[0]->name(); + if(mod_name.find("flash_decoding") != std::string::npos) + { + found_flash_decoding_p2 = true; + } + } + } + } + + // Below threshold: should have regular attention, not flash decoding + EXPECT(not found_flash_decoding_p1); + EXPECT(found_regular_attention_p1); // Should have regular attention instead + // At threshold: should have flash decoding + 
EXPECT(found_flash_decoding_p2); +} + +TEST_CASE(flash_decoding_auto_split_max_splits_constraint) +{ + // Test that max_splits constraint is respected + migraphx::shape s_3d{migraphx::shape::half_type, {1, 2048, 2048}}; + + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + auto a = mm->add_parameter("q", s_3d); + auto b = mm->add_parameter("k", s_3d); + auto b1 = mm->add_parameter("v", s_3d); + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); + rsum = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", gemm1->get_shape().lens()}}), rsum); + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + mm->add_return({gemm2}); + } + // Use small max_splits to test constraint + run_pass(p1, + {.attn_enabled = true, + .flash_decoding_enabled = true, + .flash_decoding_num_splits = 0, // Auto-calculate + .flash_decoding_threshold = 32, + .flash_decoding_max_splits = 4, // Small max_splits + .flash_decoding_min_chunk_size = 64}); + + // Check that flash decoding was applied + auto* mod = p1.get_main_module(); + bool found_flash_decoding = std::any_of(mod->begin(), mod->end(), [](const auto& ins) { + return ins.name().find("group") != std::string::npos; + }); + + EXPECT(found_flash_decoding); +} + +TEST_CASE(ceil_mul_of_function) 
+{ + // Test the ceil_mul_of function used for padding calculations + + // Test exact multiples - no padding needed + EXPECT(migraphx::ceil_mul_of(16, 4) == 16); // 16 is multiple of 4 + EXPECT(migraphx::ceil_mul_of(32, 8) == 32); // 32 is multiple of 8 + EXPECT(migraphx::ceil_mul_of(100, 10) == 100); // 100 is multiple of 10 + + // Test non-multiples - padding needed + EXPECT(migraphx::ceil_mul_of(17, 4) == 20); // 17 -> 20 (next multiple of 4) + EXPECT(migraphx::ceil_mul_of(33, 8) == 40); // 33 -> 40 (next multiple of 8) + EXPECT(migraphx::ceil_mul_of(101, 10) == 110); // 101 -> 110 (next multiple of 10) + + // Test edge cases + EXPECT(migraphx::ceil_mul_of(1, 4) == 4); // 1 -> 4 + EXPECT(migraphx::ceil_mul_of(0, 4) == 0); // 0 -> 0 + + // Test specific cases from attention fusion + EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 100 -> 104 (padding = 4) + EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 127 -> 128 (padding = 1) + EXPECT(migraphx::ceil_mul_of(2049, 32) == 2080); // 2049 -> 2080 (padding = 31) +} + int main(int argc, const char* argv[]) { test::run(argc, argv); diff --git a/test/math_utils_test.cpp b/test/math_utils_test.cpp new file mode 100644 index 00000000000..e0a3c9099ab --- /dev/null +++ b/test/math_utils_test.cpp @@ -0,0 +1,467 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include "test.hpp" + +TEST_CASE(integer_divide_ceil_basic) +{ + // Test exact division + EXPECT(migraphx::integer_divide_ceil(10, 5) == 2); + EXPECT(migraphx::integer_divide_ceil(12, 4) == 3); + EXPECT(migraphx::integer_divide_ceil(15, 3) == 5); + + // Test division with remainder (should round up) + EXPECT(migraphx::integer_divide_ceil(10, 3) == 4); // 10/3 = 3.33... -> 4 + EXPECT(migraphx::integer_divide_ceil(11, 3) == 4); // 11/3 = 3.66... -> 4 + EXPECT(migraphx::integer_divide_ceil(13, 5) == 3); // 13/5 = 2.6 -> 3 + + // Test with 1 + EXPECT(migraphx::integer_divide_ceil(1, 1) == 1); + EXPECT(migraphx::integer_divide_ceil(5, 1) == 5); + EXPECT(migraphx::integer_divide_ceil(100, 1) == 100); + + // Test edge cases + EXPECT(migraphx::integer_divide_ceil(0, 5) == 0); + EXPECT(migraphx::integer_divide_ceil(1, 2) == 1); + EXPECT(migraphx::integer_divide_ceil(1, 10) == 1); +} + +TEST_CASE(integer_divide_ceil_large_numbers) +{ + // Test with larger numbers + EXPECT(migraphx::integer_divide_ceil(1000, 7) == 143); // 1000/7 = 142.857... -> 143 + EXPECT(migraphx::integer_divide_ceil(1024, 32) == 32); // Exact division + EXPECT(migraphx::integer_divide_ceil(1025, 32) == 33); // 1025/32 = 32.03... 
-> 33 + + // Test with powers of 2 + EXPECT(migraphx::integer_divide_ceil(127, 8) == 16); // 127/8 = 15.875 -> 16 + EXPECT(migraphx::integer_divide_ceil(128, 8) == 16); // Exact division + EXPECT(migraphx::integer_divide_ceil(129, 8) == 17); // 129/8 = 16.125 -> 17 +} + +TEST_CASE(ceil_mul_of_basic) +{ + // Test exact multiples (no rounding needed) + EXPECT(migraphx::ceil_mul_of(10, 5) == 10); // 10 is already a multiple of 5 + EXPECT(migraphx::ceil_mul_of(12, 4) == 12); // 12 is already a multiple of 4 + EXPECT(migraphx::ceil_mul_of(15, 3) == 15); // 15 is already a multiple of 3 + + // Test rounding up to next multiple + EXPECT(migraphx::ceil_mul_of(11, 5) == 15); // Next multiple of 5 after 11 is 15 + EXPECT(migraphx::ceil_mul_of(13, 4) == 16); // Next multiple of 4 after 13 is 16 + EXPECT(migraphx::ceil_mul_of(17, 3) == 18); // Next multiple of 3 after 17 is 18 + + // Test with 1 (should always return the original number) + EXPECT(migraphx::ceil_mul_of(5, 1) == 5); + EXPECT(migraphx::ceil_mul_of(100, 1) == 100); + + // Test edge cases + EXPECT(migraphx::ceil_mul_of(0, 5) == 0); + EXPECT(migraphx::ceil_mul_of(1, 5) == 5); + EXPECT(migraphx::ceil_mul_of(1, 10) == 10); +} + +TEST_CASE(ceil_mul_of_powers_of_two) +{ + // Test with powers of 2 (common in GPU programming) + EXPECT(migraphx::ceil_mul_of(100, 32) == 128); // 32 * 4 = 128 + EXPECT(migraphx::ceil_mul_of(128, 32) == 128); // Already aligned + EXPECT(migraphx::ceil_mul_of(129, 32) == 160); // 32 * 5 = 160 + + EXPECT(migraphx::ceil_mul_of(250, 64) == 256); // 64 * 4 = 256 + EXPECT(migraphx::ceil_mul_of(256, 64) == 256); // Already aligned + EXPECT(migraphx::ceil_mul_of(257, 64) == 320); // 64 * 5 = 320 + + // Warp size alignment (32 threads) + EXPECT(migraphx::ceil_mul_of(30, 32) == 32); + EXPECT(migraphx::ceil_mul_of(32, 32) == 32); + EXPECT(migraphx::ceil_mul_of(33, 32) == 64); + EXPECT(migraphx::ceil_mul_of(64, 32) == 64); + EXPECT(migraphx::ceil_mul_of(65, 32) == 96); +} + 
+TEST_CASE(ceil_mul_of_flash_attention_use_case)
+{
+    // Test the specific use case from flash attention
+    // Simulating the padding of sequence length to be divisible by number of groups
+
+    // Example 1: sequence_length=100, num_groups=8
+    EXPECT(migraphx::ceil_mul_of(100, 8) == 104); // 8 * 13 = 104
+
+    // Example 2: sequence_length=127, num_groups=16
+    EXPECT(migraphx::ceil_mul_of(127, 16) == 128); // 16 * 8 = 128
+
+    // Example 3: sequence_length=200, num_groups=32
+    EXPECT(migraphx::ceil_mul_of(200, 32) == 224); // 32 * 7 = 224
+
+    // Example 4: Already divisible
+    EXPECT(migraphx::ceil_mul_of(192, 32) == 192); // Already divisible by 32
+}
+
+TEST_CASE(ceil_mul_of_consistency_with_integer_divide_ceil)
+{
+    // Verify that ceil_mul_of(x, y) == y * integer_divide_ceil(x, y)
+    // This tests the implementation relationship
+
+    const std::size_t test_cases[][2] = {
+        {10, 3}, {15, 4}, {100, 7}, {256, 32}, {1000, 13}, {1, 10}, {0, 5}};
+
+    for(const auto* const tc : test_cases)
+    {
+        std::size_t x = tc[0];
+        std::size_t y = tc[1];
+
+        std::size_t expected = y * migraphx::integer_divide_ceil(x, y);
+        std::size_t actual = migraphx::ceil_mul_of(x, y);
+
+        EXPECT(actual == expected);
+    }
+}
+
+TEST_CASE(ceil_mul_of_large_numbers)
+{
+    // Test with larger sequence lengths that might appear in real models
+    EXPECT(migraphx::ceil_mul_of(1024, 16) == 1024); // Already aligned
+    EXPECT(migraphx::ceil_mul_of(1025, 16) == 1040); // 16 * 65 = 1040
+    EXPECT(migraphx::ceil_mul_of(2048, 64) == 2048); // Already aligned
+    EXPECT(migraphx::ceil_mul_of(2049, 64) == 2112); // 64 * 33 = 2112
+
+    // Very large numbers
+    EXPECT(migraphx::ceil_mul_of(10000, 128) == 10112); // 128 * 79 = 10112
+    EXPECT(migraphx::ceil_mul_of(16384, 256) == 16384); // Already aligned
+    EXPECT(migraphx::ceil_mul_of(16385, 256) == 16640); // 256 * 65 = 16640
+}
+
+TEST_CASE(split_dim_basic)
+{
+    // Test with no max_splits constraint (using default)
+
+    // Keeps splitting 100 while the remaining chunk is > 10
+    // 100 = 2^2 * 5^2; factors used: 2,2,5 -> splits = 20, remaining = 5
+    std::size_t dim100 = 100;
+    std::size_t result100 = migraphx::split_dim(dim100, 10);
+    EXPECT(result100 == 20); // 100/20 = 5, stops because 5 <= 10
+
+    // Keeps splitting 64 while the remaining chunk is > 10
+    // 64 = 2^6; can use 2,2,2 -> splits = 8, remaining = 8
+    std::size_t dim64 = 64;
+    std::size_t result64 = migraphx::split_dim(dim64, 10);
+    EXPECT(result64 == 8); // 64/8 = 8, stops because 8 <= 10
+
+    // Should not split if already small enough
+    std::size_t dim10 = 10;
+    std::size_t result10 = migraphx::split_dim(dim10, 10);
+    EXPECT(result10 == 1); // 10 is not > 10, so no split
+
+    std::size_t dim11 = 11;
+    std::size_t result11 = migraphx::split_dim(dim11, 10);
+    EXPECT(result11 == 11); // 11 is a factor itself, 11/11 = 1
+
+    // Prime numbers that can't be factored
+    std::size_t dim13 = 13;
+    std::size_t result13 = migraphx::split_dim(dim13, 10);
+    EXPECT(result13 == 1); // 13 is prime (not in factor list)
+
+    std::size_t dim17 = 17;
+    std::size_t result17 = migraphx::split_dim(dim17, 10);
+    EXPECT(result17 == 1); // 17 is prime (not in factor list)
+
+    // Numbers with factors in [2,3,5,7,11]
+    std::size_t dim30 = 30;
+    std::size_t result30 = migraphx::split_dim(dim30, 5);
+    EXPECT(result30 == 6); // 30 = 2*3*5, uses 2*3 = 6, remaining = 5
+
+    std::size_t dim77 = 77;
+    std::size_t result77 = migraphx::split_dim(dim77, 10);
+    EXPECT(result77 ==
+           77); // can be evenly split into 11 size chunks; next divisor splits to 1 size chunks
+}
+
+TEST_CASE(split_dim_with_max_splits)
+{
+    // Test with explicit max_splits constraint
+    // Note: max_splits is NOT a hard cap - splitting stops once the split factor is >= max_splits,
+    // so the result can overshoot it (e.g. 1000 with max_splits=16 yields 40)
+
+    // Splitting stops as soon as the split factor reaches max_splits
+    std::size_t dim100a = 100;
+    std::size_t result100a = migraphx::split_dim(dim100a, 10, 4);
+    EXPECT(result100a ==
+           4); // 100 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4
+
+    std::size_t dim100b = 100;
+    std::size_t result100b = migraphx::split_dim(dim100b, 10, 2);
+    EXPECT(result100b == 2);
+
+    // Max splits doesn't force splitting if min_size constraint would be violated
+    std::size_t dim20 = 20;
+    std::size_t result20 = migraphx::split_dim(dim20, 10, 4);
+    EXPECT(result20 == 2); // Can only split to 2 (20/2=10, not > 10)
+
+    std::size_t dim15 = 15;
+    std::size_t result15 = migraphx::split_dim(dim15, 10, 4);
+    EXPECT(result15 == 3); // 15 = 3*5, splits to 3, remaining = 5
+
+    // Test with powers of 2
+    std::size_t dim128a = 128;
+    std::size_t result128a = migraphx::split_dim(dim128a, 10, 8);
+    EXPECT(
+        result128a ==
+        8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8
+
+    std::size_t dim128b = 128;
+    std::size_t result128b = migraphx::split_dim(dim128b, 10, 4);
+    EXPECT(result128b ==
+           4); // 128 can be divided by 2*2, which is 4 splits, which is not less than max_splits=4
+
+    std::size_t dim128c = 128;
+    std::size_t result128c = migraphx::split_dim(dim128c, 20, 8);
+    EXPECT(
+        result128c ==
+        8); // 128 can be divided by 2*2*2, which is 8 splits, which is not less than max_splits=8
+}
+
+TEST_CASE(split_dim_edge_cases)
+{
+    // Test edge cases
+
+    // Very small dimensions
+    std::size_t dim1 = 1;
+    std::size_t result1 = migraphx::split_dim(dim1, 0);
+    EXPECT(result1 == 1); // 1 can't be split
+
+    std::size_t dim2a = 2;
+    std::size_t result2a = migraphx::split_dim(dim2a, 0);
+    EXPECT(result2a == 2); // 2/2 = 1 > 0
+
+    std::size_t dim2b = 2;
+    std::size_t result2b = migraphx::split_dim(dim2b, 1);
+    EXPECT(result2b == 2); // 2/2 = 1, but we continue while r > min_size, so 2 > 1 allows split
+
+    // Exact boundary conditions
+    std::size_t dim20a = 20;
+    std::size_t result20a = migraphx::split_dim(dim20a, 9);
+    EXPECT(result20a == 4); // 20 = 2^2 * 5, factors 2,2 before 20/4=5 <= 9
+
+    std::size_t dim20b = 20;
+    std::size_t result20b = migraphx::split_dim(dim20b, 10);
+    EXPECT(result20b == 2); // 20 = 2^2 * 5, factors 2 before 20/2=10 <= 10
+
+    std::size_t dim21 = 21;
+    std::size_t result21 = migraphx::split_dim(dim21, 10);
+    EXPECT(result21 == 3); // 21 = 3*7, splits by 3 first, 21/3 = 7 <= 10
+
+    // Large prime numbers
+    std::size_t dim97 = 97;
+    std::size_t result97 = migraphx::split_dim(dim97, 10);
+    EXPECT(result97 == 1); // 97 is prime
+
+    std::size_t dim101 = 101;
+    std::size_t result101 = migraphx::split_dim(dim101, 10);
+    EXPECT(result101 == 1); // 101 is prime
+}
+
+TEST_CASE(split_dim_factorization_order)
+{
+    // Test that factorization happens in the expected order [2,3,5,7,11]
+
+    // 60 = 2^2 * 3 * 5
+    // With min_size=10: 60->30->15->5 (stops because 5 <= 10)
+    // Factors used: 2, 2, 3 (product = 12)
+    std::size_t dim60 = 60;
+    std::size_t result60 = migraphx::split_dim(dim60, 10);
+    EXPECT(result60 == 12); // 60/12 = 5
+
+    // 210 = 2 * 3 * 5 * 7
+    // With min_size=20: continues factoring while 210 > 20
+    // Factors all: 2*3*5*7 = 210, but stops at 2*3*5 = 30 since 210/30 = 7 <= 20
+    std::size_t dim210 = 210;
+    std::size_t result210 = migraphx::split_dim(dim210, 20);
+    EXPECT(result210 == 30); // 210/30 = 7
+
+    // 462 = 2 * 3 * 7 * 11
+    // With min_size=30: 462->231->77->11 (stops because 11 <= 30)
+    std::size_t dim462 = 462;
+    std::size_t result462 = migraphx::split_dim(dim462, 30);
+    EXPECT(result462 == 42); // 462/42 = 11 <= 30
+}
+
+TEST_CASE(split_dim_reduce_use_case)
+{
+    // Test the specific use case from reduce.cpp
+    // These are realistic values that might appear in reduction operations
+
+    // Large reduction dimension with typical min_size and max_splits
+    std::size_t dim1024 = 1024;
+    std::size_t result1024 = migraphx::split_dim(dim1024, 64, 16);
+    EXPECT(result1024 ==
+           16); // 1024 = 2^10; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops
+
+    std::size_t dim1000 = 1000;
+    std::size_t result1000 = migraphx::split_dim(dim1000, 64, 16);
+    EXPECT(result1000 == 40); // 1000 = 2^3 * 5^3; n=8, r=125 > 64, so next would be n=40 > 16
+
+    // Smaller dimensions
+    std::size_t dim256 = 256;
+    std::size_t result256 = migraphx::split_dim(dim256, 32, 8);
+    EXPECT(result256 == 8); // 256 = 2^8; n=8, r=32; stops (r not > 32, n not < max_splits=8)
+
+    std::size_t dim200 = 200;
+    std::size_t result200 = migraphx::split_dim(dim200, 32, 8);
+    EXPECT(result200 == 8); // 200 = 2^3 * 5^2; n=8, r=25; stops (r not > 32, n not < max_splits=8)
+}
+
+TEST_CASE(split_dim_flash_attention_use_case)
+{
+    // Test use cases from flash attention decoding
+    // Sequence lengths need to be split for parallel processing
+
+    // Typical sequence lengths in attention
+    std::size_t dim2048 = 2048;
+    std::size_t result2048 = migraphx::split_dim(dim2048, 128, 16);
+    EXPECT(result2048 ==
+           16); // 2048 = 2^11; when n=8 < 16, multiplies to 16; when n=16 < 16 is false, stops
+
+    std::size_t dim4096 = 4096;
+    std::size_t result4096 = migraphx::split_dim(dim4096, 256, 16);
+    EXPECT(result4096 == 16); // 4096 = 2^12; splits to 16 (4096/16=256)
+
+    // Non-aligned sequence lengths
+    // 1536 = 2^9 * 3; only factors of 2 are consumed: 1536->768->384->192->96
+    std::size_t dim1536 = 1536;
+    std::size_t result1536 = migraphx::split_dim(dim1536, 128, 16);
+    EXPECT(result1536 == 16); // 1536 = 2^9 * 3; n=16, r=96, stops (r not > 128)
+
+    // 3000 = 2^3 * 3 * 5^3
+    std::size_t dim3000 = 3000;
+    std::size_t result3000 = migraphx::split_dim(dim3000, 200, 16);
+    EXPECT(result3000 == 24); // 2^3 * 3 = 24, 3000/24 = 125
+
+    // Smaller sequences
+    std::size_t dim512 = 512;
+    std::size_t result512 = migraphx::split_dim(dim512, 64, 8);
+    EXPECT(result512 == 8); // 512 = 2^9; stops at n=8 (512/8=64)
+
+    std::size_t dim768 = 768;
+    std::size_t result768 = migraphx::split_dim(dim768, 64, 8);
+    EXPECT(result768 == 8); // 768 = 2^8 * 3; n=8, r=96 > 64, stops because n reached max_splits=8
+}
+
+TEST_CASE(split_dim_no_limit_vs_explicit_max)
+{
+    // Verify that default (no limit) and explicit max give same results when max is large
+    std::size_t large_max = std::numeric_limits<std::size_t>::max();
+
+    // These should produce identical results
+    std::size_t dim1000a = 1000;
+    std::size_t result1000a = migraphx::split_dim(dim1000a, 10);
+    std::size_t dim1000b = 1000;
+    std::size_t result1000b = migraphx::split_dim(dim1000b, 10, large_max);
+    EXPECT(result1000a == result1000b);
+
+    std::size_t dim512a = 512;
+    std::size_t result512a = migraphx::split_dim(dim512a, 32);
+    std::size_t dim512b = 512;
+    std::size_t result512b = migraphx::split_dim(dim512b, 32, large_max);
+    EXPECT(result512a == result512b);
+
+    std::size_t dim360a = 360;
+    std::size_t result360a = migraphx::split_dim(dim360a, 20);
+    std::size_t dim360b = 360;
+    std::size_t result360b = migraphx::split_dim(dim360b, 20, large_max);
+    EXPECT(result360a == result360b);
+
+    // Test that max_splits affects the result (but doesn't necessarily limit it)
+    // max_splits acts as a threshold - result may exceed it
+    std::size_t dim1000c = 1000;
+    std::size_t result1000c = migraphx::split_dim(dim1000c, 10);
+    std::size_t dim1000d = 1000;
+    std::size_t result1000d = migraphx::split_dim(dim1000d, 10, 8);
+    EXPECT(result1000c != result1000d);
+
+    std::size_t dim512c = 512;
+    std::size_t result512c = migraphx::split_dim(dim512c, 8);
+    std::size_t dim512d = 512;
+    std::size_t result512d = migraphx::split_dim(dim512d, 8, 16);
+    EXPECT(result512c != result512d);
+}
+
+TEST_CASE(split_dim_consistency_check)
+{
+    // Verify important properties of split_dim
+
+    // Property 1: Whenever a split happens, the split count evenly divides the dimension
+    // The remaining chunk may end up <= min_size, so only divisibility is checked here
+    for(std::size_t dim : {100, 256, 512, 1000, 2048})
+    {
+        for(std::size_t min_size : {8, 16, 32, 64})
+        {
+            std::size_t dim_copy = dim;
+            std::size_t splits = migraphx::split_dim(dim_copy, min_size);
+            if(splits > 1)
+            {
+                // The algorithm continues while r > min_size,
+                // so it stops when r <= min_size
+                // We can't guarantee remaining > min_size,
+                // but we know the split count must be a divisor
+                EXPECT(dim % splits == 0);
+                EXPECT(dim / splits > 0);
+            }
+        }
+    }
+
+    // Property 2: With max_splits, splitting stops once remaining <= min_size or splits >= max_splits
+    // Note: max_splits is NOT a hard cap - it's a threshold like min_size
+    for(std::size_t dim : {100, 256, 512, 1000})
+    {
+        for(std::size_t max_splits : {4, 8, 16})
+        {
+            std::size_t dim_copy = dim;
+            std::size_t splits = migraphx::split_dim(dim_copy, 10, max_splits);
+            // Result should evenly divide dimension
+            EXPECT(dim % splits == 0);
+            // Remaining size should be <= min_size (10) unless max_splits was reached first
+            EXPECT(dim / splits <= 10 or splits >= max_splits);
+        }
+    }
+
+    // Property 3: Increasing min_size decreases or maintains splits
+    for(std::size_t dim : {256, 512, 1024})
+    {
+        std::size_t dim_copy1 = dim;
+        std::size_t splits_8 = migraphx::split_dim(dim_copy1, 8);
+        std::size_t dim_copy2 = dim;
+        std::size_t splits_16 = migraphx::split_dim(dim_copy2, 16);
+        std::size_t dim_copy3 = dim;
+        std::size_t splits_32 = migraphx::split_dim(dim_copy3, 32);
+
+        EXPECT(splits_8 >= splits_16);
+        EXPECT(splits_16 >= splits_32);
+    }
+}
+
+int main(int argc, const char* argv[]) { test::run(argc, argv); }
diff --git a/test/verify/test_attention_flash_decoding_3d.cpp b/test/verify/test_attention_flash_decoding_3d.cpp
index a0aa41619d5..9d71733758d 100644
--- a/test/verify/test_attention_flash_decoding_3d.cpp
+++ b/test/verify/test_attention_flash_decoding_3d.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -32,27 +32,38 @@ struct test_attention_flash_decoding_3d : verify_programadd_parameter("q", s_3d); - auto b = mm->add_parameter("k", s_3d); - auto b1 = mm->add_parameter("v", s_3d); - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b); - b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), b1); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), gemm1); - rmax = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), - rmax); - auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); - auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); - auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), exp); - rsum = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", s_3d.lens()}}), - rsum); - auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); - auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); + auto a = mm->add_parameter("q", q_shape); + auto b = mm->add_parameter("k", k_shape); + auto b1 = mm->add_parameter("v", v_shape); + + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); // {1, 64, 256} + auto rmax = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), + gemm1); // {1, 64, 1} + rmax = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), rmax); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, rmax); // {1, 64, 256} + auto exp = mm->add_instruction(migraphx::make_op("exp"), sub); // {1, 64, 256} + auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), + exp); // {1, 64, 1} + rsum = + 
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 64, 256}}}), + rsum); // {1, 64, 256} + auto div = mm->add_instruction(migraphx::make_op("div"), exp, rsum); // {1, 64, 256} + auto gemm2 = mm->add_instruction(migraphx::make_op("dot"), div, b1); // {1, 64, 32} mm->add_return({gemm2}); return p1; } diff --git a/test/verify/test_attention_flash_decoding_4d.cpp b/test/verify/test_attention_flash_decoding_4d.cpp index ffbe725b6e2..725cf296ae0 100644 --- a/test/verify/test_attention_flash_decoding_4d.cpp +++ b/test/verify/test_attention_flash_decoding_4d.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -62,7 +62,6 @@ struct test_attention_flash_decoding_4d : verify_program; -// template struct test_attention_flash_decoding_4d; -// template struct test_attention_flash_decoding_4d; +template struct test_attention_flash_decoding_4d; +template struct test_attention_flash_decoding_4d; +template struct test_attention_flash_decoding_4d;