Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions include/infinicore/ops/mha_kvcache.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#pragma once

#include "../device.hpp"
#include "../graph/graph.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

// Flash Attention KV-cache decode op.
//
// Wraps flash::mha_fwd_kvcache for single-step (decode) attention over a
// paged KV cache.
//
// Tensor shapes:
//   out         : [batch_size, seqlen_q, num_heads, head_size]
//   q           : [batch_size, seqlen_q, num_heads, head_size]
//   k_cache     : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
//   v_cache     : [num_blocks, block_size, num_heads_k, head_size] (paged layout)
//   seqlens_k   : [batch_size] int32 — total KV length per request
//   block_table : [batch_size, max_num_blocks_per_seq] int32

// Declares the MhaKVCache graph op with the argument list below; the macro
// generates the class, per-device dispatch tables, and graph record/run
// plumbing (see common/op.hpp).
INFINICORE_GRAPH_OP_CLASS(
    MhaKVCache,
    Tensor,                // out (written in place)
    const Tensor &,        // q
    const Tensor &,        // k_cache
    const Tensor &,        // v_cache
    const Tensor &,        // seqlens_k
    const Tensor &,        // block_table
    std::optional<Tensor>, // alibi_slopes (nullopt = no ALiBi bias)
    float);                // scale (softmax scaling factor, typically 1/sqrt(head_size))

// Functional form: allocates the output tensor (same shape/dtype/device as q)
// and returns it.
Tensor mha_kvcache(const Tensor &q,
                   const Tensor &k_cache,
                   const Tensor &v_cache,
                   const Tensor &seqlens_k,
                   const Tensor &block_table,
                   std::optional<Tensor> alibi_slopes,
                   float scale);

// In-place form: writes the attention result into a caller-provided `out`.
void mha_kvcache_(Tensor out,
                  const Tensor &q,
                  const Tensor &k_cache,
                  const Tensor &v_cache,
                  const Tensor &seqlens_k,
                  const Tensor &block_table,
                  std::optional<Tensor> alibi_slopes,
                  float scale);

} // namespace infinicore::op
58 changes: 58 additions & 0 deletions src/infinicore/ops/mha_kvcache/mha_kvcache.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include "infinicore/ops/mha_kvcache.hpp"
#include "../../utils.hpp"

namespace infinicore::op {

// Instantiates the per-device dispatcher storage declared by
// INFINICORE_GRAPH_OP_CLASS in the header.
INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(MhaKVCache);

// Validates inputs, then dispatches to the implementation registered for
// the device type of `out` (all tensors are asserted to share it).
MhaKVCache::MhaKVCache(Tensor out,
                       const Tensor &q,
                       const Tensor &k_cache,
                       const Tensor &v_cache,
                       const Tensor &seqlens_k,
                       const Tensor &block_table,
                       std::optional<Tensor> alibi_slopes,
                       float scale) {
    // alibi_slopes is optional and deliberately excluded from this check.
    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, seqlens_k, block_table);
    INFINICORE_GRAPH_OP_DISPATCH(out->device().getType(),
                                 out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
}

// Entry point used by both eager and graph-capture modes: the macro either
// records this op into the active graph or runs it immediately.
void MhaKVCache::execute(Tensor out,
                         const Tensor &q,
                         const Tensor &k_cache,
                         const Tensor &v_cache,
                         const Tensor &seqlens_k,
                         const Tensor &block_table,
                         std::optional<Tensor> alibi_slopes,
                         float scale) {
    INFINICORE_GRAPH_OP_RECORD_OR_RUN(
        MhaKVCache,
        out, q, k_cache, v_cache, seqlens_k, block_table, alibi_slopes, scale);
}

// In-place MHA KV-cache decode: writes the attention result into `out`.
//
// `out` and `alibi_slopes` are sink parameters (taken by value) that this
// wrapper does not use again, so they are moved into execute() to avoid an
// extra Tensor copy (Tensor copies are not free even if refcounted).
void mha_kvcache_(Tensor out,
                  const Tensor &q,
                  const Tensor &k_cache,
                  const Tensor &v_cache,
                  const Tensor &seqlens_k,
                  const Tensor &block_table,
                  std::optional<Tensor> alibi_slopes,
                  float scale) {
    MhaKVCache::execute(std::move(out), q, k_cache, v_cache, seqlens_k, block_table,
                        std::move(alibi_slopes), scale);
}

// Functional MHA KV-cache decode: allocates the output and returns it.
//
// Returns a new tensor matching q's shape/dtype/device:
//   [batch_size, seqlen_q, num_heads, head_size]
Tensor mha_kvcache(const Tensor &q,
                   const Tensor &k_cache,
                   const Tensor &v_cache,
                   const Tensor &seqlens_k,
                   const Tensor &block_table,
                   std::optional<Tensor> alibi_slopes,
                   float scale) {
    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
    // alibi_slopes is a by-value sink not used again here — move it through
    // rather than copying. `out` is still needed for the return, so it is
    // passed by (cheap handle) copy.
    mha_kvcache_(out, q, k_cache, v_cache, seqlens_k, block_table,
                 std::move(alibi_slopes), scale);
    return out;
}

} // namespace infinicore::op
85 changes: 85 additions & 0 deletions src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include "infinicore/ops/mha_kvcache.hpp"

#include "infinicore/adaptor/flash_attention_adaptor.hpp"

namespace infinicore::op::mha_kvcache_impl::flashattn {

// Snapshot of one op invocation, captured at plan() time and replayed by
// run(). GraphTensor wrappers keep the tensors alive/resolvable across
// graph replays; ownership of this struct is handed out as an opaque
// void* and reclaimed by cleanup().
struct PlannedMeta {
    graph::GraphTensor out, q, k_cache, v_cache, seqlens_k, block_table;
    std::optional<graph::GraphTensor> alibi_slopes; // nullopt = no ALiBi bias
    float scale;                                    // softmax scaling factor
};

// Captures all arguments as graph tensors for later replay by run().
//
// Returns an owning PlannedMeta* erased to void*; the raw `new` is
// intentional — the graph framework stores the opaque pointer and releases
// it through the paired cleanup() hook, so RAII cannot apply here.
void *plan(Tensor out,
           const Tensor &q,
           const Tensor &k_cache,
           const Tensor &v_cache,
           const Tensor &seqlens_k,
           const Tensor &block_table,
           std::optional<Tensor> alibi_slopes,
           float scale) {
    return new PlannedMeta{
        graph::GraphTensor(out),
        graph::GraphTensor(q),
        graph::GraphTensor(k_cache),
        graph::GraphTensor(v_cache),
        graph::GraphTensor(seqlens_k),
        graph::GraphTensor(block_table),
        // alibi_slopes is owned by value here and never used again, so the
        // wrapped Tensor can be moved instead of copied.
        alibi_slopes ? std::optional<graph::GraphTensor>(graph::GraphTensor(std::move(*alibi_slopes)))
                     : std::nullopt,
        scale};
}

// Replays a planned MhaKVCache invocation via FlashAttention's
// mha_fwd_kvcache kernel on infinicore's current CUDA stream.
void run(void *planned_meta) {
    // Ensure the FlashAttention launch goes onto infinicore's stream rather
    // than PyTorch's current default stream.
    c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
    auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

    // Convert graph tensors to ATen views; optionals match the parameter
    // types expected by flash::mha_fwd_kvcache.
    auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));
    auto q = infinicore::adaptor::to_aten_tensor(p->q);
    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
    auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
    auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
    auto alibi_slopes = p->alibi_slopes
                            ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes))
                            : std::nullopt;

    // No new KV tokens to append (pure decode, KV already written to cache).
    std::optional<const at::Tensor> k_new = std::nullopt;
    std::optional<const at::Tensor> v_new = std::nullopt;
    std::optional<const at::Tensor> rotary_cos = std::nullopt;
    std::optional<const at::Tensor> rotary_sin = std::nullopt;
    std::optional<const at::Tensor> cache_batch_idx = std::nullopt;
    std::optional<const at::Tensor> leftpad_k = std::nullopt;

    // NOTE(review): this call is positional and order-sensitive; the argument
    // sequence below must match flash::mha_fwd_kvcache's declaration exactly —
    // verify against the FlashAttention header if that signature changes.
    flash::mha_fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k_new,
        v_new,
        seqlens_k,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        leftpad_k,
        block_table,
        alibi_slopes,
        out,      // pre-allocated output buffer (written in place)
        p->scale,
        true,     // is_causal
        -1,       // window_size_left (-1 = no sliding window)
        -1,       // window_size_right
        0.0f,     // softcap
        false,    // is_rotary_interleaved
        0         // num_splits (0 = auto)
    );
}

// Releases the PlannedMeta allocated by plan() and nulls the caller's slot
// so an accidental second cleanup is a harmless `delete nullptr`.
void cleanup(void **planned_meta_ptr) {
    // static_cast is the correct named cast for void* -> T*: the stored value
    // is a PlannedMeta* that was erased to void*, so no pointer
    // reinterpretation (and no reinterpret_cast through void**) is needed.
    delete static_cast<PlannedMeta *>(*planned_meta_ptr);
    *planned_meta_ptr = nullptr;
}

INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(MhaKVCache, &plan, &run, &cleanup);

} // namespace infinicore::op::mha_kvcache_impl::flashattn
Loading