
[runtime][vm] Add LoRA adapter metadata to paged KV cache #18890

Open
MagellaX wants to merge 3 commits into apache:main from MagellaX:feat/multilora-kvcache-pr1-apache

Conversation


@MagellaX MagellaX commented Mar 8, 2026

Summary

This PR adds minimal runtime support for LoRA adapter metadata in paged KV cache state.

This is intended as a small runtime building block for downstream multi-LoRA serving work discussed here:

Changes

  • add lora_adapter_id to Sequence
  • add runtime KV-cache APIs to:
    • set sequence adapter id
    • get sequence adapter id
    • get current-batch adapter ids after BeginForward
  • make ForkSequence inherit the parent adapter id
  • populate current-batch adapter ids during BeginForward
  • add Python runtime tests covering:
    • default adapter id
    • set/get
    • fork inheritance
    • current-batch ordering
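
The bookkeeping described above can be sketched as a toy pure-Python model. The class and method names below mirror the PR's description (Sequence, ForkSequence, BeginForward, the set/get APIs) but are illustrative only, and the default id of -1 (meaning "no adapter") is an assumption rather than necessarily what the runtime uses.

```python
# Toy model of the adapter-id plumbing this PR describes; names and the
# default id of -1 are illustrative assumptions, not the actual TVM API.

class Sequence:
    def __init__(self, lora_adapter_id=-1):
        self.lora_adapter_id = lora_adapter_id

class ToyPagedKVCache:
    def __init__(self):
        self.seqs = {}
        # Mirrors current_lora_adapter_ids_host_ in the PR.
        self.current_lora_adapter_ids = []

    def add_sequence(self, seq_id):
        self.seqs[seq_id] = Sequence()

    def set_sequence_lora_adapter(self, seq_id, adapter_id):
        self.seqs[seq_id].lora_adapter_id = adapter_id

    def get_sequence_lora_adapter(self, seq_id):
        return self.seqs[seq_id].lora_adapter_id

    def fork_sequence(self, parent_id, child_id):
        # Child inherits the parent's adapter id, as ForkSequence does here.
        self.seqs[child_id] = Sequence(self.seqs[parent_id].lora_adapter_id)

    def begin_forward(self, seq_ids):
        # Populate the current-batch adapter ids in batch order.
        self.current_lora_adapter_ids = [
            self.seqs[s].lora_adapter_id for s in seq_ids
        ]

    def get_current_lora_adapter_ids(self):
        return list(self.current_lora_adapter_ids)
```

The four tested behaviors fall out directly: a fresh sequence reports the default id, set/get round-trips, a forked child copies the parent's id, and `begin_forward` records ids in the order the batch was given.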

Scope

This is runtime-state plumbing only.

This PR does not yet add:

  • frontend request plumbing
  • LoRA runtime execution in model ops
  • multi-adapter batching logic
  • specialized LoRA kernels

Validation

Validated on a fresh apache/tvm base with a local source build on Modal...

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request establishes the foundational runtime support for LoRA adapter metadata within the paged KV cache. It provides the necessary plumbing to track and manage LoRA adapter IDs at the sequence level, which is a crucial building block for future multi-LoRA serving capabilities. The changes are focused purely on runtime state management and do not yet include frontend request handling, LoRA execution in model operations, or specialized batching logic.

Highlights

  • LoRA Adapter ID in Sequence: The Sequence struct now includes a lora_adapter_id field to associate a LoRA adapter with each sequence in the KV cache.
  • New KV Cache APIs: Introduced new runtime APIs to set and get the LoRA adapter ID for a specific sequence, and to retrieve the LoRA adapter IDs for the current batch after BeginForward.
  • LoRA ID Inheritance: When a sequence is forked, the child sequence now inherits the lora_adapter_id from its parent.
  • Batch LoRA ID Population: The current_lora_adapter_ids_host_ is populated with the adapter IDs of the sequences in the current batch during BeginForward.
  • Python Runtime Tests: Added comprehensive Python runtime tests to validate the default adapter ID, set/get functionality, fork inheritance, and correct ordering of current-batch adapter IDs.


Changelog
  • src/runtime/vm/attn_utils.h
    • Added lora_adapter_id to the Sequence struct.
  • src/runtime/vm/kv_state.cc
    • Registered new FFI methods for setting, getting, and retrieving current batch LoRA adapter IDs.
  • src/runtime/vm/kv_state.h
    • Declared virtual methods SetSequenceLoraAdapter, GetSequenceLoraAdapter, and GetCurrentLoraAdapterIds in AttentionKVCacheObj.
  • src/runtime/vm/paged_kv_cache.cc
    • Added current_lora_adapter_ids_host_ member to store LoRA adapter IDs for the current batch.
    • Initialized current_lora_adapter_ids_host_ during object construction.
    • Modified ForkSequence to ensure child sequences inherit the parent's lora_adapter_id.
    • Populated current_lora_adapter_ids_host_ with sequence LoRA adapter IDs during BeginForward.
    • Implemented SetSequenceLoraAdapter to update a sequence's LoRA adapter ID.
    • Implemented GetSequenceLoraAdapter to retrieve a sequence's LoRA adapter ID.
    • Implemented GetCurrentLoraAdapterIds to return the LoRA adapter IDs of the current batch.
  • tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
    • Declared and assigned global function variables for the new LoRA adapter APIs.
    • Added test_lora_adapter_metadata to verify LoRA adapter ID functionality, including default values, setting/getting, fork inheritance, and current batch ordering.
Activity
  • No human activity (comments, reviews) has been recorded on this pull request yet.

@MagellaX
Author

MagellaX commented Mar 8, 2026

@babusid lmk your thoughts!!!

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for LoRA adapter metadata within the paged KV cache. The changes are well-structured, adding a lora_adapter_id to the Sequence struct, exposing new FFI functions to manage these IDs, and ensuring correct propagation during sequence forking. The implementation is clean and includes a comprehensive set of tests covering the new functionality. I have one minor suggestion to improve code reuse and simplify the implementation.

@babusid
Contributor

babusid commented Mar 10, 2026

> @babusid lmk your thoughts!!!

I've been mulling over what the best architecture for Multi-LoRA support might look like, and I don't think the original breakdown I gave you was the best. Sorry about that.

While the S-LoRA paper does use paged memory blocks to handle the adaptor weights, given that TVM's construct is explicitly a PagedKVCache, I don't think overloading it to hold adaptor weights is the cleanest.

I'm wondering if we're not better served by implementing a separate memory pool construct, in the same vein as PagedKVCache.cc or RNNState.cc, designed specifically for holding LoRA adaptor weights at runtime.

That said, by restructuring the efforts, we might be able to keep things strictly downstream (not completely sure on this last point).

Let me discuss a bit more internally with @MasterJH5574 and others, and see if I can better formalize an idea / design.
I also welcome your thoughts on the topic.
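
The separate memory-pool idea above might look roughly like the following pure-Python sketch: a standalone pool that holds adapter weights keyed by adapter id, entirely outside the KV cache, with simple LRU eviction. Every name here is invented for illustration; it is not a proposed TVM API.

```python
# Hypothetical sketch of a standalone LoRA adapter-weight pool, kept separate
# from PagedKVCache as suggested above. All names are illustrative.
from collections import OrderedDict

class ToyLoraPool:
    """Holds adapter weights keyed by adapter id, with simple LRU eviction."""

    def __init__(self, capacity=2):
        self.capacity = capacity
        self._pool = OrderedDict()  # adapter_id -> weights

    def load(self, adapter_id, weights):
        # Insert (or refresh) an adapter, evicting the least recently
        # used adapter once the pool exceeds capacity.
        if adapter_id in self._pool:
            self._pool.move_to_end(adapter_id)
        self._pool[adapter_id] = weights
        if len(self._pool) > self.capacity:
            self._pool.popitem(last=False)

    def get(self, adapter_id):
        # Look up weights for a forward step; a miss would trigger a
        # (re)load in a real runtime. Touching the entry marks it as
        # recently used.
        if adapter_id not in self._pool:
            raise KeyError(f"adapter {adapter_id} not resident")
        self._pool.move_to_end(adapter_id)
        return self._pool[adapter_id]
```

The point of the sketch is the boundary: the KV cache would only ever see adapter ids, while weight residency and eviction live in a construct like this, analogous to how PagedKVCache and RNNState are separate runtime objects.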

@MagellaX
Author

> @babusid lmk your thoughts!!!
>
> I've been mulling over what the best architecture for Multi-LoRA support might look like, and I don't think the original breakdown I gave you was the best. Sorry about that.
>
> While the S-LoRA paper does use paged memory blocks to handle the adaptor weights, given that TVM's construct is explicitly a PagedKVCache, I don't think overloading it to hold adaptor weights is the cleanest.
>
> I'm wondering if we're not better served by implementing a separate memory pool construct, in the same vein as PagedKVCache.cc or RNNState.cc, designed specifically for holding LoRA adaptor weights at runtime.
>
> That said, by restructuring the efforts, we might be able to keep things strictly downstream (not completely sure on this last point).
>
> Let me discuss a bit more internally with @MasterJH5574 and others, and see if I can better formalize an idea / design. I also welcome your thoughts on the topic.

yeah, after digging through S-LoRA/Punica again and also looking at how the actual serving stacks handle this, I think you’re probably right....

The consistent pattern seems to be that adapter identity needs to flow with the request and factor into cache correctness, but the adapter weights themselves usually live in a separate adapter cache / pool, not inside the KV-cache object. S-LoRA’s unified paging reads more like a lower-level allocator/kernel strategy than the right API boundary. TRT-LLM is especially explicit about this since it keeps a LoRA cache distinct from the KV cache, and vLLM / SGLang also seem to converge on separate adapter management with adapter ids still affecting cache correctness.

So my current leaning is to avoid pushing actual LoRA weights into PagedKVCache, and instead re-scope around a separate LoRAPool / adapter-pool runtime object, with adapter id threaded separately. Then if unified paging still makes sense later, it can sit underneath both KV and adapter storage rather than overloading the KV abstraction.
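
One way the "adapter id threaded separately" idea could connect back to this PR's runtime APIs: given the current-batch adapter ids the cache exposes after BeginForward, a batching layer can group batch positions by adapter so each adapter's LoRA delta is applied in one batched operation (the pattern Punica/S-LoRA-style kernels rely on). This is an illustrative sketch, not any serving stack's actual code.

```python
# Illustrative sketch: group batch positions by adapter id, as a
# multi-adapter batching layer might do with the ids the runtime
# exposes after BeginForward.
from collections import defaultdict

def group_by_adapter(current_adapter_ids):
    """Map adapter_id -> list of batch positions using that adapter."""
    groups = defaultdict(list)
    for pos, adapter_id in enumerate(current_adapter_ids):
        groups[adapter_id].append(pos)
    return dict(groups)
```

For a batch whose ids are `[0, 1, 0, -1, 1]`, this yields positions `[0, 2]` for adapter 0, `[1, 4]` for adapter 1, and `[3]` for the no-adapter group, so each group can be dispatched to its adapter's weights in one go.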

@MasterJH5574
Contributor

@MagellaX Thank you so much for the contributions! My overall read is that we probably need to first establish an end-to-end LoRA serving flow, with runnable tests and real commands, before upstreaming parts. The main reason is that we don't want to iterate on the implementations too many times in the mainline repo without seeing end-to-end effects.

@MagellaX
Author

> @MagellaX Thank you so much for the contributions! My overall read is that we probably need to first establish an end-to-end LoRA serving flow, with runnable tests and real commands, before upstreaming parts. The main reason is that we don't want to iterate on the implementations too many times in the mainline repo without seeing end-to-end effects.

yeah totally agree with that!!! I think with this PR the main useful outcome was clarifying the runtime boundary a bit, but I agree the next step should be downstream first, not upstream first. I’ll focus on getting a minimal end-to-end LoRA serving path working in MLC with real runnable commands, tests, and a clear single-adapter flow, and then only upstream the smallest TVM pieces that are actually required by that working path. That should make the design much easier to evaluate and avoid churning upstream abstractions too early.
