diff --git a/include/infinicore/ops/mha_kvcache.hpp b/include/infinicore/ops/mha_kvcache.hpp
new file mode 100644
index 000000000..2769e4e39
--- /dev/null
+++ b/include/infinicore/ops/mha_kvcache.hpp
@@ -0,0 +1,51 @@
+#pragma once
+
+#include "../device.hpp"
+#include "../graph/graph.hpp"
+#include "common/op.hpp"
+#include <optional>
+
+namespace infinicore::op {
+
+// Flash Attention KV-cache decode op.
+//
+// Wraps flash::mha_fwd_kvcache for single-step (decode) attention over a
+// paged KV cache.
+//
+// Tensor shapes:
+//   out         : [batch_size, seqlen_q, num_heads, head_size]
+//   q           : [batch_size, seqlen_q, num_heads, head_size]
+//   k_cache     : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
+//   v_cache     : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
+//   seqlens_k   : [batch_size] int32 — total KV length per request
+//   block_table : [batch_size, max_num_blocks_per_seq] int32
+
+INFINICORE_GRAPH_OP_CLASS(
+    MhaKVCache,
+    Tensor,                 // out
+    const Tensor &,         // q
+    const Tensor &,         // k_cache
+    const Tensor &,         // v_cache
+    const Tensor &,         // seqlens_k
+    const Tensor &,         // block_table
+    std::optional<Tensor>,  // alibi_slopes
+    float);                 // scale
+
+Tensor mha_kvcache(const Tensor &q,
+                   const Tensor &k_cache,
+                   const Tensor &v_cache,
+                   const Tensor &seqlens_k,
+                   const Tensor &block_table,
+                   std::optional<Tensor> alibi_slopes,
+                   float scale);
+
+void mha_kvcache_(Tensor out,
+                  const Tensor &q,
+                  const Tensor &k_cache,
+                  const Tensor &v_cache,
+                  const Tensor &seqlens_k,
+                  const Tensor &block_table,
+                  std::optional<Tensor> alibi_slopes,
+                  float scale);
+
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache.cc
new file mode 100644
index 000000000..0c5b3ae8c
--- /dev/null
+++ b/src/infinicore/ops/mha_kvcache/mha_kvcache.cc
@@ -0,0 +1,58 @@
+#include "infinicore/ops/mha_kvcache.hpp"
+#include "../../utils.hpp"
+
+namespace infinicore::op {
+
+INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MhaKVCache);
+
+MhaKVCache::MhaKVCache(Tensor out,
+                       const Tensor &q,
+                       const Tensor &k_cache,
+                       const Tensor &v_cache,
+                       const Tensor &seqlens_k,
+                       const Tensor &block_table,
+                       std::optional<Tensor> alibi_slopes,
+                       float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, seqlens_k, block_table);
+    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
+                                 out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
+}
+
+void MhaKVCache::execute(Tensor out,
+                         const Tensor &q,
+                         const Tensor &k_cache,
+                         const Tensor &v_cache,
+                         const Tensor &seqlens_k,
+                         const Tensor &block_table,
+                         std::optional<Tensor> alibi_slopes,
+                         float scale) {
+    INFINICORE_GRAPH_OP_RECORD_OR_RUN(
+        MhaKVCache,
+        out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
+}
+
+void mha_kvcache_(Tensor out,
+                  const Tensor &q,
+                  const Tensor &k_cache,
+                  const Tensor &v_cache,
+                  const Tensor &seqlens_k,
+                  const Tensor &block_table,
+                  std::optional<Tensor> alibi_slopes,
+                  float scale) {
+    MhaKVCache::execute(out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
+}
+
+Tensor mha_kvcache(const Tensor &q,
+                   const Tensor &k_cache,
+                   const Tensor &v_cache,
+                   const Tensor &seqlens_k,
+                   const Tensor &block_table,
+                   std::optional<Tensor> alibi_slopes,
+                   float scale) {
+    // Output shape matches q: [batch_size, seqlen_q, num_heads, head_size]
+    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
+    mha_kvcache_(out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
+    return out;
+}
+
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
new file mode 100644
index 000000000..d74fdbb00
--- /dev/null
+++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
@@ -0,0 +1,85 @@
+#include "infinicore/ops/mha_kvcache.hpp"
+
+#include "infinicore/adaptor/flash_attention_adaptor.hpp"
+
+namespace infinicore::op::mha_kvcache_impl::flashattn {
+
+struct PlannedMeta {
+    graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table;
+    std::optional<graph::GraphTensor> alibi_slopes;
+    float scale;
+};
+
+void *plan(Tensor out,
+           const Tensor &q,
+           const Tensor &k_cache,
+           const Tensor &v_cache,
+           const Tensor &seqlens_k,
+           const Tensor &block_table,
+           std::optional<Tensor> alibi_slopes,
+           float scale) {
+    return new PlannedMeta{
+        graph::GraphTensor(out),
+        graph::GraphTensor(q),
+        graph::GraphTensor(k_cache),
+        graph::GraphTensor(v_cache),
+        graph::GraphTensor(seqlens_k),
+        graph::GraphTensor(block_table),
+        alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(*alibi_slopes)) : std::nullopt,
+        scale};
+}
+
+void run(void *planned_meta) {
+    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
+    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
+
+    auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
+    auto q = infinicore::adaptor::to_aten_tensor(p->q);
+    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
+    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
+    auto seqlens_k = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
+    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
+    auto alibi_slopes = p->alibi_slopes
+                            ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
+                            : std::nullopt;
+
+    // No new KV tokens to append (pure decode, KV already written to cache).
+    std::optional<at::Tensor> k_new = std::nullopt; // NOTE(review): optional payload type reconstructed — confirm vs flash::mha_fwd_kvcache (may be std::optional<const at::Tensor>)
+    std::optional<at::Tensor> v_new = std::nullopt;
+    std::optional<at::Tensor> rotary_cos = std::nullopt;
+    std::optional<at::Tensor> rotary_sin = std::nullopt;
+    std::optional<at::Tensor> cache_batch_idx = std::nullopt;
+    std::optional<at::Tensor> leftpad_k = std::nullopt;
+
+    flash::mha_fwd_kvcache(
+        q,
+        k_cache,
+        v_cache,
+        k_new,
+        v_new,
+        seqlens_k,
+        rotary_cos,
+        rotary_sin,
+        cache_batch_idx,
+        leftpad_k,
+        block_table,
+        alibi_slopes,
+        out,
+        p->scale,
+        true,  // is_causal
+        -1,    // window_size_left (-1 = no sliding window)
+        -1,    // window_size_right
+        0.0f,  // softcap
+        false, // is_rotary_interleaved
+        0      // num_splits (0 = auto)
+    );
+}
+
+void cleanup(void **planned_meta_ptr) {
+    delete *reinterpret_cast<PlannedMeta **>(planned_meta_ptr);
+    *planned_meta_ptr = nullptr;
+}
+
+INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MhaKVCache, &plan, &run, &cleanup);
+
+} // namespace infinicore::op::mha_kvcache_impl::flashattn