Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3a9cff6
initial
bdevorem Nov 19, 2025
8259f9a
auto split
bdevorem Nov 24, 2025
7895119
blah
bdevorem Nov 24, 2025
4a9e918
make 3d test more interesting
bdevorem Dec 4, 2025
d237085
remove rocmlir update since it is on mainline now
bdevorem Dec 4, 2025
b7e9f22
remove comments and clean up
bdevorem Dec 5, 2025
d7f59a1
Merge branch 'develop' into bdevorem/flash-decoding-r2
bdevorem Dec 5, 2025
2c20d78
AIMIGRAPHX-289 AIMIGRAPHX-341 ; remove comment
bdevorem Dec 5, 2025
6bda8f7
format
bdevorem Dec 5, 2025
cefa6e2
fix defaults, remove comments, clean up helper func
bdevorem Dec 5, 2025
c749ab3
update docs
bdevorem Dec 5, 2025
7eeacc3
AIMIGRAPHX-341 handle uneven splits
bdevorem Dec 5, 2025
2824c92
format
bdevorem Dec 5, 2025
23f5cbb
update doc
bdevorem Dec 5, 2025
ff755dd
format
bdevorem Dec 11, 2025
5d08ab0
add helper func for ceil; cursor made test file
bdevorem Dec 11, 2025
fe2cb8f
update helper func and comments; cursor made tests
bdevorem Dec 12, 2025
8e0d454
make split_dim by ref
bdevorem Dec 12, 2025
8a6dd3b
format
bdevorem Dec 12, 2025
dd7eb98
cursor made test cases to handle drop in codecov
bdevorem Dec 12, 2025
5b726f3
move group calculation into separate function
bdevorem Dec 12, 2025
86ebe6f
rework splits/groups math and member vars. need a member var to be ab…
bdevorem Dec 12, 2025
3705ded
fix busted cusor tests, clang-format & tidy
bdevorem Dec 12, 2025
8c888cc
remove stupid tests that cursor keeps adding back
bdevorem Dec 17, 2025
2111563
code cov tests; refactor a bit
bdevorem Dec 29, 2025
26324c8
format
bdevorem Dec 29, 2025
5162656
Merge branch 'develop' into bdevorem/flash-decoding-r2
bdevorem Dec 29, 2025
de3eb6b
tidy
bdevorem Dec 30, 2025
9cfd664
add more tests for codecov and fix bug in padding
bdevorem Dec 31, 2025
7173c41
format
bdevorem Dec 31, 2025
b2cd8ad
change a loop, disable flash decoding on mi300
bdevorem Jan 1, 2026
19c7ede
happy new yeargit add -u!
bdevorem Jan 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions docs/reference/MIGraphX-dev-env-vars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,49 @@ Model performance tunable variables change the compilation behavior of a model.

| Default: Split-k performance configurations are turned off.

* - | ``MIGRAPHX_FLASH_DECODING_ENABLED``
- | When set, flash decoding optimization for attention fusion is enabled, which splits the key-value sequence dimension for improved performance on long sequences.
- | ``1``: Enables flash decoding optimization.
| ``0``: Disables flash decoding optimization.

| Default: ``0`` (disabled).

* - | ``MIGRAPHX_FLASH_DECODING_NUM_SPLITS``
| Turns on flash decoding for attention fusion and sets the number of splits along the key-value sequence dimension.
| Sets the number of splits along the key-value sequence dimension when flash decoding is enabled.

- | ``0`` or negative: Automatically calculates optimal splits based on sequence length.
| ``N`` (where N > 0): Uses exactly N splits along the key-value sequence dimension.

| Default: ``0`` (automatic calculation).

| Note: This variable implicitly sets ``MIGRAPHX_FLASH_DECODING_ENABLED=1``. If not set, and ``MIGRAPHX_FLASH_DECODING_ENABLED=1``, the number of splits will be automatically calculated.

* - | ``MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE``
| Sets the minimum chunk size for automatic split calculation in flash decoding.

- | ``0``: Flash decoding is turned off (i.e., number of splits is 0).
| ``N`` (where N > 1): Enables flash decoding with N splits along the key-value sequence dimension. For example, ``2`` enables flash decoding with 2 splits, ``4`` with 4 splits, etc.
- | Takes a positive integer representing the minimum size of each chunk after splitting.

| Default: ``32``.

| Note: Only used when automatic split calculation is enabled.

| Default: flash decoding is turned off.
* - | ``MIGRAPHX_FLASH_DECODING_MAX_SPLITS``
| Sets the maximum number of splits allowed during automatic split calculation.

- | Takes a positive integer representing the maximum number of splits.

| Default: ``16``.

| Note: Only used when automatic split calculation is enabled.

* - | ``MIGRAPHX_FLASH_DECODING_THRESHOLD``
| Sets the minimum sequence length threshold for flash decoding to be applied.

- | Takes a positive integer. Flash decoding is only applied when the sequence length is greater than or equal to this threshold.

| Default: ``32``.

| Note: Only used when automatic split calculation is enabled.

* - | ``MIGRAPHX_DISABLE_FP16_INSTANCENORM_CONVERT``
| When set, FP16 is not converted to FP32 in the ``InstanceNormalization`` ONNX operator.
Expand Down
208 changes: 175 additions & 33 deletions src/fuse_attention.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,7 +28,9 @@
#include <migraphx/matcher.hpp>
#include <migraphx/match/softmax.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/generic_float.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/split_factor.hpp>
#include <queue>
#include <optional>

Expand All @@ -37,9 +39,81 @@ inline namespace MIGRAPHX_INLINE_NS {

namespace {

// env vars for flash decoding configuration
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_NUM_SPLITS);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_MAX_SPLITS);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_FLASH_DECODING_THRESHOLD);

// Helper function to get config value with priority: struct member (if not default) > env var >
// default
template <typename EnvVar>
std::size_t get_config_value(std::size_t struct_value, std::size_t default_value, EnvVar env_var)
{
// if struct member is not the default value, use it
if(struct_value != default_value)
{
return struct_value;
}

// otherwise return env var value, or default if not set
return value_of(env_var, default_value);
}

// Get num_splits with priority: struct member > env var > 0 (not set)
std::size_t get_num_splits(std::size_t member_num_splits)
{
// if struct member is set (non-zero), use it
if(member_num_splits > 0)
{
return member_num_splits;
}

// otherwise return env var value, or 0 if not set
return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0);
}

std::size_t get_num_splits() { return value_of(MIGRAPHX_FLASH_DECODING_NUM_SPLITS{}, 0); }
// calculate optimal flash decoding splits
inline std::size_t calculate_flash_decoding_splits(std::size_t sequence_length,
std::size_t min_chunk_size,
std::size_t max_splits)
{
std::size_t r = sequence_length;
return split_dim(r, min_chunk_size, max_splits);
}

// calculate the actual number of groups for flash decoding
// returns 0 if no splitting should be performed
inline std::size_t calculate_groups(std::size_t groups,
std::size_t sequence_length,
std::size_t threshold,
std::size_t min_chunk_size,
std::size_t max_splits)
{
// if groups is explicitly set and valid, use it
if(groups > 1)
return groups;

// if groups is 0, auto-calculate based on sequence length
if(groups == 0)
{
// skip if sequence is too short
if(sequence_length < threshold)
return 0;

std::size_t actual_groups =
calculate_flash_decoding_splits(sequence_length, min_chunk_size, max_splits);

// return 0 if auto-calculation determines no splitting needed
if(actual_groups <= 1)
return 0;

return actual_groups;
}

// groups == 1 or invalid, no splitting
return 0;
}

// TODO: Write this in matcher.hpp as a general matcher for iterating through inputs
inline auto pointwise_inputs()
Expand Down Expand Up @@ -216,8 +290,11 @@ struct find_attention

struct find_flash_decoding
{
// number of groups. User-provided for now
std::size_t groups;
// configuration from fuse_attention pass config
std::size_t configured_splits;
std::size_t configured_threshold;
std::size_t configured_max_splits;
std::size_t configured_min_chunk_size;

auto matcher() const
{
Expand Down Expand Up @@ -267,7 +344,8 @@ struct find_flash_decoding
std::vector<size_t> v_shape; // final V shape: [B, G, N/G, D]
};

transformed_shapes_result get_transformed_shapes(const std::vector<shape>& input_shapes) const
transformed_shapes_result get_transformed_shapes(const std::vector<shape>& input_shapes,
std::size_t num_groups) const
{
assert(input_shapes.size() == 3 and "Expected Q, K, V shapes");

Expand All @@ -279,11 +357,11 @@ struct find_flash_decoding
// 4D: Q_lens = [B, H, M, k]
size_t ndim = q_lens.size();
size_t n = k_lens[ndim - 1];
size_t g = groups;
size_t g = num_groups;

// TODO: handle uneven splits; this is caught in `apply` for now
assert(n % g == 0 and
"Key-value sequence length must be divisible by number of splits/groups");
// Note: sequence length may have been padded to be divisible by num_groups
assert(n % g == 0 and "Key-value sequence length must be divisible by number of "
"splits/groups (after padding)");
size_t n_split = n / g;

transformed_shapes_result result;
Expand Down Expand Up @@ -449,15 +527,31 @@ struct find_flash_decoding
assert(k_param->name() == "@param" and "K should be a parameter");
assert(v_param->name() == "@param" and "V should be a parameter");

// check if N dimension is evenly divisible by num_splits
if(k_param->get_shape().lens().back() % groups != 0)
// Get sequence length from K shape
auto k_shape = k_param->get_shape();
std::size_t sequence_length = k_shape.lens().back();

// read configuration with priority: struct member (if not default) > env var > default
std::size_t groups = get_num_splits(configured_splits);
std::size_t threshold =
get_config_value(configured_threshold, 32, MIGRAPHX_FLASH_DECODING_THRESHOLD{});
std::size_t min_chunk_size = get_config_value(
configured_min_chunk_size, 32, MIGRAPHX_FLASH_DECODING_MIN_CHUNK_SIZE{});
std::size_t max_splits =
get_config_value(configured_max_splits, 16, MIGRAPHX_FLASH_DECODING_MAX_SPLITS{});

std::size_t actual_groups =
calculate_groups(groups, sequence_length, threshold, min_chunk_size, max_splits);
if(actual_groups == 0)
return;

// get Q, V, K shapes from gemms
auto qkv_shapes = get_qkv_shapes(q_param, k_param, v_param);

// check shapes are ok and get flash decoding transformed shapes (Q', V', K')
auto transform_info = get_transformed_shapes(qkv_shapes);
// calculate padding if sequence length not evenly divisible
std::size_t padding_needed = 0;
if(sequence_length % actual_groups != 0)
{
// round up to nearest multiple of actual_groups
padding_needed = ceil_mul_of(sequence_length, actual_groups) - sequence_length;
}

// create mapping from submodule params to main module inputs
auto group_inputs = attn_group_ins->inputs();
Expand All @@ -468,6 +562,39 @@ struct find_flash_decoding
auto k = map_param_to_main.at(k_param);
auto v = map_param_to_main.at(v_param);

// save original references before padding (needed for group_inputs replacement later)
auto q_orig = q;
auto k_orig = k;
auto v_orig = v;

// pad Q, K and V if necessary
if(padding_needed > 0)
{
// Q shape: [B, M, k] or [B, H, M, k] for 4D. Padding on M (sequence length dim)
auto q_ndim = q->get_shape().ndim();
std::vector<std::size_t> q_pads(2 * q_ndim, 0);
q_pads[q_ndim + q_ndim - 2] = padding_needed; // pad right on M dim (second to last)
q = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", q_pads}}), q);

// K shape: [B, k, N] or [B, H, k, N] for 4D. Padding on N
auto k_ndim = k->get_shape().ndim();
std::vector<std::size_t> k_pads(2 * k_ndim, 0);
k_pads[k_ndim + k_ndim - 1] = padding_needed; // pad right on last dim
k = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", k_pads}}), k);

// V shape: [B, N, D] or [B, H, N, D] for 4D
auto v_ndim = v->get_shape().ndim();
std::vector<std::size_t> v_pads(2 * v_ndim, 0);
v_pads[v_ndim + v_ndim - 2] = padding_needed; // pad right on N dim
v = mm.insert_instruction(attn_group_ins, make_op("pad", {{"pads", v_pads}}), v);
}

// get Q, K, V shapes (using potentially padded K and V)
auto qkv_shapes = get_qkv_shapes(q, k, v);

// check shapes are ok and get flash decoding transformed shapes (Q', V', K')
auto transform_info = get_transformed_shapes(qkv_shapes, actual_groups);

// insert reshape operations before group, for Q, K, V
auto q_ndim = q->get_shape().lens().size();
int64_t g_axis = q_ndim - 2;
Expand All @@ -493,18 +620,19 @@ struct find_flash_decoding
attn_group_ins, make_op("reshape", {{"dims", transform_info.v_shape}}), v);

// create new input list by replacing Q, K, V with reshaped versions
// use original references (before padding) for comparison
std::vector<instruction_ref> new_group_inputs = group_inputs;
for(size_t i = 0; i < group_inputs.size(); ++i)
{
if(group_inputs[i] == q)
if(group_inputs[i] == q_orig)
{
new_group_inputs[i] = q_reshaped;
}
else if(group_inputs[i] == k)
else if(group_inputs[i] == k_orig)
{
new_group_inputs[i] = k_reshaped;
}
else if(group_inputs[i] == v)
else if(group_inputs[i] == v_orig)
{
new_group_inputs[i] = v_reshaped;
}
Expand Down Expand Up @@ -604,8 +732,26 @@ struct find_flash_decoding
auto final_squeezed_o = mm.insert_instruction(
attn_group_ins, make_op("squeeze", {{"axes", {g_axis}}}), final_output_o);

// if padding was applied, slice to remove it
instruction_ref final_result = final_squeezed_o;
if(padding_needed > 0)
{
// need to slice the sequence dimension to remove padding
// final_squeezed_o has shape like [B, M_padded, D], need to slice M back to original
auto output_shape = final_squeezed_o->get_shape();
const auto& output_lens = output_shape.lens();
std::size_t seq_dim_idx = output_lens.size() - 2; // sequence dim is second to last
std::size_t original_seq_len = output_lens[seq_dim_idx] - padding_needed;

final_result = mm.insert_instruction(
attn_group_ins,
make_op("slice",
{{"axes", {seq_dim_idx}}, {"starts", {0}}, {"ends", {original_seq_len}}}),
final_squeezed_o);
}

// replace the original group instruction with the final result
mm.replace_instruction(attn_group_ins, final_squeezed_o);
mm.replace_instruction(attn_group_ins, final_result);
}
};

Expand Down Expand Up @@ -797,20 +943,16 @@ void fuse_attention::apply(module_pass_manager& mpm) const
mpm.run_pass(dead_code_elimination{});
}

std::size_t num_splits = 0;
if(flash_decoding_num_splits.has_value())
{
// Use the value from the constructor (for testing)
num_splits = *flash_decoding_num_splits;
}
else
{
// Default behavior: read from the env var (for non-test use)
num_splits = get_num_splits();
}
if(num_splits > 1)
// enable flash decoding if splits configured or flash decoding is enabled
std::size_t configured_splits = get_num_splits(flash_decoding_num_splits);
if(configured_splits > 0 or flash_decoding_enabled)
{
match::find_matches(mpm, find_flash_decoding{.groups = num_splits});
match::find_matches(
mpm,
find_flash_decoding{.configured_splits = flash_decoding_num_splits,
.configured_threshold = flash_decoding_threshold,
.configured_max_splits = flash_decoding_max_splits,
.configured_min_chunk_size = flash_decoding_min_chunk_size});
mpm.run_pass(dead_code_elimination{});
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/include/migraphx/fuse_attention.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -38,7 +38,11 @@ struct module_pass_manager;
struct MIGRAPHX_EXPORT fuse_attention
{
bool attn_enabled = false;
std::optional<std::size_t> flash_decoding_num_splits = std::nullopt;
bool flash_decoding_enabled = false;
std::size_t flash_decoding_num_splits = 0;
std::size_t flash_decoding_threshold = 32;
std::size_t flash_decoding_max_splits = 16;
std::size_t flash_decoding_min_chunk_size = 32;

std::string name() const { return "fuse_attention"; }
void apply(module_pass_manager& mpm) const;
Expand Down
Loading
Loading