4 changes: 2 additions & 2 deletions cli/alora/train.py
@@ -19,6 +19,7 @@
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
PreTrainedTokenizerBase,
TrainerCallback,
TrainerControl,
TrainerState,
@@ -47,7 +48,7 @@


def load_dataset_from_json(
json_path: str, tokenizer: AutoTokenizer, invocation_prompt: str
json_path: str, tokenizer: PreTrainedTokenizerBase, invocation_prompt: str
) -> Dataset:
"""Load a JSONL dataset and format it for SFT training.

@@ -218,7 +219,6 @@ def train_model(
base_model, padding_side="right", trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens = False

dataset = load_dataset_from_json(dataset_path, tokenizer, invocation_prompt)
dataset = dataset.shuffle(seed=42)
2 changes: 1 addition & 1 deletion docs/docs/advanced/prefix-caching-and-kv-blocks.md
@@ -63,7 +63,7 @@ When a prompt contains a mix of cached and uncached blocks, Mellea:
2. Runs forward passes on uncached blocks.
3. Retrieves stored `DynamicCache` for cached blocks.
4. **Smashes** (concatenates) all KV caches along the time axis using
`merge_dynamic_caches()`.
`merge_dynamic_caches_v5()`.
5. Passes the merged cache plus the combined input IDs to the generation step.

The result is identical to a single full-context forward pass, with the prefill
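The numbered steps above map directly onto the helpers this PR introduces. Below is a minimal sketch of the smash-then-generate flow, assuming the `merge_v5` helper from `mellea.backends.kv_block_helpers` and a small HF causal LM; it prefills every block rather than retrieving stored caches, so it illustrates the steps rather than Mellea's internal code path.

```python
# Sketch only: relies on merge_v5 / merge_dynamic_caches_v5 added in this PR.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from mellea.backends.kv_block_helpers import merge_v5

model_id = "ibm-granite/granite-4.0-tiny-preview"  # any HF causal LM should work
device = torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(model_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Steps 1-4: prefill a DynamicCache per block, then smash them along the time axis.
blocks = ["<cached system prompt>", "<cached document>", "<fresh user question>"]
merged_toks, merged_masks, merged_cache = merge_v5(model, tokenizer, blocks, device)

# Step 5: generate from the merged cache plus the combined input IDs.
out = model.generate(
    merged_toks,
    attention_mask=merged_masks,
    past_key_values=merged_cache,
    use_cache=True,
    max_new_tokens=64,
)
print(tokenizer.decode(out[0, merged_toks.shape[1]:], skip_special_tokens=True))
```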
6 changes: 3 additions & 3 deletions docs/kv_smash/kv_with_chat.py
@@ -1,7 +1,7 @@
import torch

from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches_v5
from mellea.backends.model_ids import IBM_GRANITE_4_HYBRID_MICRO

backend = LocalHFBackend(model_id=IBM_GRANITE_4_HYBRID_MICRO)
@@ -30,7 +30,7 @@ def cache(s: str, store=True) -> DynamicCache:
def merge(toks, dcs):
merged_toks = torch.cat([t["input_ids"] for t in toks], dim=1)
merged_masks = torch.cat([t["attention_mask"] for t in toks], dim=1)
merged_dcs = merge_dynamic_caches(dcs)
merged_dcs = merge_dynamic_caches_v5(dcs)

return merged_toks, merged_masks, merged_dcs

@@ -89,7 +89,7 @@ def merge(toks, dcs):
# Merge everything together.
merged_toks = torch.cat([toks["input_ids"] for toks in tok_parts], dim=1)
merged_masks = torch.cat([toks["attention_mask"] for toks in tok_parts], dim=1)
merged_dcs = merge_dynamic_caches(dc_parts)
merged_dcs = merge_dynamic_caches_v5(dc_parts)

# crop the last KV for safety.
merged_dcs.crop(-1)
17 changes: 10 additions & 7 deletions docs/kv_smash/kvcache.py
@@ -4,17 +4,20 @@
# "mellea[hf]",
# ]
# ///
from typing import cast

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from transformers.generation import GenerateDecoderOnlyOutput
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from transformers.generation.utils import GenerateDecoderOnlyOutput
from transformers.modeling_utils import PreTrainedModel

from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches
from mellea.backends.kv_block_helpers import DynamicCache, merge_dynamic_caches_v5

model_id = "ibm-granite/granite-4.0-tiny-preview"
device = torch.device("mps")
model = AutoModelForCausalLM.from_pretrained(model_id)
model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(model_id) # type: ignore[assignment]
# model = model.to(device=device) # this part does not pass mypy; possible misconfiguration
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(model_id)


def cache(toks) -> DynamicCache:
@@ -34,7 +37,7 @@ def merge(strs: list[str]):

merged_toks = torch.cat([toks["input_ids"] for toks in strs_toks], dim=1)
merged_masks = torch.cat([toks["attention_mask"] for toks in strs_toks], dim=1)
merged_dcs = merge_dynamic_caches(strs_dcs)
merged_dcs = merge_dynamic_caches_v5(strs_dcs)

return merged_toks, merged_masks, merged_dcs

@@ -45,7 +48,7 @@ def merge(strs: list[str]):
merged_dcs.crop(-1)

# GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput | GenerateBeamDecoderOnlyOutput | GenerateBeamEncoderDecoderOutput | LongTensor
result = model.generate(
result = model.generate( # type: ignore[operator]
merged_toks.to(model.device),
attention_mask=merged_masks.to(model.device),
past_key_values=merged_dcs,
2 changes: 1 addition & 1 deletion docs/metrics/coverage-current.json
@@ -355,7 +355,7 @@
"TokenizedCacheIterleaving",
"LegacyCache",
"legacy_cache_smash",
"merge_dynamic_caches",
"merge_dynamic_caches_v5",
"tokens_to_legacy_cache"
],
"mellea.backends.huggingface.granite_formatters": [
25 changes: 16 additions & 9 deletions mellea/backends/huggingface.py
@@ -12,7 +12,7 @@
import json
import threading
from collections.abc import Callable, Coroutine, Sequence
from typing import Any, overload
from typing import Any, cast, overload

import llguidance
import llguidance.hf
@@ -24,7 +24,7 @@
from transformers.generation.streamers import AsyncTextIteratorStreamer
from transformers.generation.utils import GenerateDecoderOnlyOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_utils import set_seed

from ..backends import kv_block_helpers
@@ -71,7 +71,7 @@

Huggingface backends can initialize themselves from a model string if the transformers `Auto*` classes can be used. Therefore, a TransformersTorchConfig usually isn't required. However, sometimes a model needs special care to instantiate properly, or a custom device type needs to be used. Instead of trying to do a lot of partial magic, we basically have two modalities: either the constructor can figure out everything from the model_id, or the user has to provide an entire config.
"""
TransformersTorchConfig = tuple[PreTrainedTokenizer, PreTrainedModel, torch.device]
TransformersTorchConfig = tuple[PreTrainedTokenizerBase, PreTrainedModel, torch.device]

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors
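For the second modality, where the caller supplies the whole config, a hedged sketch of building the `TransformersTorchConfig` tuple by hand follows; the `custom_config` keyword name is an assumption taken from the variable unpacked in `__init__`, so verify the actual `LocalHFBackend` signature before relying on it.

```python
# Sketch only: the custom_config keyword is assumed, not confirmed by this diff.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from mellea.backends.huggingface import LocalHFBackend

model_id = "ibm-granite/granite-4.0-tiny-preview"
device = torch.device("cpu")  # e.g. a device the automatic path would not pick
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto").to(device)

# TransformersTorchConfig = tuple[PreTrainedTokenizerBase, PreTrainedModel, torch.device]
backend = LocalHFBackend(model_id=model_id, custom_config=(tokenizer, model, device))
```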

@@ -302,8 +302,8 @@ def __init__(
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
self._hf_model_id, device_map=str(self._device), torch_dtype="auto"
)
self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
self._hf_model_id
self._tokenizer: PreTrainedTokenizerBase = (
AutoTokenizer.from_pretrained(self._hf_model_id)
)
case _:
self._tokenizer, self._model, self._device = custom_config
@@ -726,7 +726,7 @@ def _make_merged_kv_cache(
[toks["attention_mask"] for toks in tok_parts], dim=1
)
assert input_ids.shape == attention_mask.shape
merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches(dc_parts)
merged_cache: DynamicCache = kv_block_helpers.merge_dynamic_caches_v5(dc_parts)
# TODO: also assert that the merged cache is the correct shape given the input_ids and attention_mask shapes.
# rewind merged cache by 1 for safety.
merged_cache.crop(-1) # type: ignore
@@ -973,7 +973,8 @@ async def _generate_from_context_standard(
"", # Empty for no adapters.
self._model.generate, # type: ignore
# Passed as args/kwargs to generate.
input_ids,
inputs=input_ids["input_ids"],
attention_mask=input_ids["attention_mask"],
return_dict_in_generate=True,
use_cache=self._use_caches, # Only create KV cache if caching is enabled
**self._make_backend_specific_and_remove(generate_options),
@@ -1045,6 +1046,8 @@ async def processing(
input_ids: The prompt token IDs used for decoding; required to slice off
the prompt portion from the generated sequences.
"""
input_ids_tensor: torch.Tensor = input_ids["input_ids"]

if mot._underlying_value is None:
mot._underlying_value = ""

@@ -1055,8 +1058,12 @@
elif isinstance(chunk, GenerateDecoderOnlyOutput):
# Otherwise, it's a non-streaming request. Decode it here.
mot._meta["hf_output"] = chunk
mot._underlying_value += self._tokenizer.decode(
chunk.sequences[0, input_ids.shape[1] :], skip_special_tokens=True
mot._underlying_value += cast(
str,
self._tokenizer.decode(
chunk.sequences[0, input_ids_tensor.shape[1] :],
skip_special_tokens=True,
),
)

async def post_processing(
176 changes: 108 additions & 68 deletions mellea/backends/kv_block_helpers.py
@@ -1,80 +1,120 @@
"""Low-level utilities for concatenating transformer KV caches (KV smashing).

Provides functions for merging ``DynamicCache`` and legacy tuple caches along the
time axis (``merge_dynamic_caches``, ``legacy_cache_smash``), and
time axis (``merge_dynamic_caches_v5``, ``legacy_cache_smash``), and
``tokens_to_legacy_cache`` for converting a tokenized prompt into a prefilled KV
cache. These helpers are used internally by local HuggingFace backends that reuse
cached prefix computations across multiple generation calls.
"""

from collections.abc import Iterable
from functools import reduce
from typing import Any
from typing import cast

import torch
from transformers import PreTrainedModel
from transformers.cache_utils import DynamicCache
from transformers.tokenization_utils_base import BatchEncoding

TokenizedCacheIterleaving = Iterable[BatchEncoding | DynamicCache]
LegacyCache = Any


def legacy_cache_smash(a: LegacyCache, b: LegacyCache) -> LegacyCache:
"""Concatenates two LegacyCache Ks and Vs along the time axis.

Args:
a: First legacy KV cache (tuple of per-layer (K, V) tensor pairs).
b: Second legacy KV cache to concatenate after ``a``.

Returns:
New legacy cache with ``b`` appended to ``a`` along the sequence dimension.
"""
legacy_merged = tuple(
(torch.cat([a[i][0], b[i][0]], dim=2), torch.cat([a[i][1], b[i][1]], dim=2))
for i in range(len(a))
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.cache_utils import CacheLayerMixin, DynamicCache
from transformers.generation.utils import GenerateDecoderOnlyOutput


@torch.no_grad()
def prefill_cache_v5(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
text: str,
device: torch.device,
) -> tuple[dict, DynamicCache]:
"""Prefills cache for transformers v5."""
toks = tokenizer(text, return_tensors="pt")
toks = {k: v.to(device) for k, v in toks.items()}

dc = DynamicCache()
out = model(
input_ids=toks["input_ids"],
attention_mask=toks["attention_mask"],
past_key_values=dc,
use_cache=True,
)
dc = out.past_key_values
dc.crop(-1)
return toks, dc # v5 returns DynamicCache (not legacy tuple)


def merge_dynamic_caches_v5(caches: Iterable[DynamicCache]) -> DynamicCache:
"""Merge multiple v5 DynamicCache objects by concatenating KV states along the time axis."""
caches = list(caches)
assert len(caches) >= 1

for c in caches:
if any(
getattr(layer, "is_sliding", False) for layer in getattr(c, "layers", [])
):
raise ValueError("Check the issue.")

merged = DynamicCache()

# reuse Cache.update() to append each segment's KV to the merged cache per layer.
# DynamicLayer.update(): self.keys = cat([self.keys, key_states], dim=-2).
for c in caches:
for layer_idx, layer in enumerate(c.layers):
if isinstance(layer, CacheLayerMixin):
if layer.keys is None or layer.values is None:
continue
merged.update(layer.keys, layer.values, layer_idx=layer_idx)

return merged


def merge_v5(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
strs: list[str],
device: torch.device,
):
"""Merges DynamicCache for transformers>=5.0.0."""
strs_toks, strs_dcs = [], []
for s in strs:
toks, dc = prefill_cache_v5(model, tokenizer, s, device)
strs_toks.append(toks)
strs_dcs.append(dc)

merged_toks = torch.cat([t["input_ids"] for t in strs_toks], dim=1)
merged_masks = torch.cat([t["attention_mask"] for t in strs_toks], dim=1)

merged_dc = merge_dynamic_caches_v5(strs_dcs)

return merged_toks, merged_masks, merged_dc


if __name__ == "__main__":
from mellea.backends.huggingface import LocalHFBackend
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B

assert IBM_GRANITE_3_3_8B.hf_model_name is not None
backend = LocalHFBackend(model_id=IBM_GRANITE_3_3_8B.hf_model_name)
_model_raw, tokenizer, device = backend._model, backend._tokenizer, backend._device
model = cast(PreTrainedModel, _model_raw)

docs = [
"Nathan Fulton is expert in large language models, formal verification, and reinforcement learning. He holds a Ph.D. from Carnegie Mellon University's Computer Science Department and has worked at Amazon Web Services and IBM Research. He currently works at IBM Research - Cambridge.",
"IBM Research has a headquarters at 1101 Kitchawan Rd in Yorktown Heights and a Cambridge office at 314 Main Street in Cambridge, MA.",
"What is the address of Nathan's place of work?",
]

merged_tokens, merged_masks, merged_cache = merge_v5(
model, tokenizer, docs, device=backend._device
)
input_ids = merged_tokens.to(device)
generate_out = cast(
GenerateDecoderOnlyOutput,
model.generate( # type: ignore[operator]
input_ids=input_ids,
use_cache=True,
return_dict_in_generate=True,
past_key_values=merged_cache,
max_new_tokens=512,
),
)
result = tokenizer.decode(
generate_out.sequences[0, input_ids.shape[1] :], skip_special_tokens=True
)
return legacy_merged


def merge_dynamic_caches(caches: Iterable[DynamicCache]) -> DynamicCache:
"""Merges two DynamicCache Ks and Vs along the time axis.

Args:
caches: Iterable of ``DynamicCache`` objects to merge in order.

Returns:
A single ``DynamicCache`` with all caches concatenated along the sequence dimension.
"""
legacies = [c.to_legacy_cache() for c in caches] # type: ignore
assert len(legacies) >= 1
rv = DynamicCache.from_legacy_cache(reduce(legacy_cache_smash, legacies)) # type: ignore
return rv # type: ignore


def tokens_to_legacy_cache(
model: PreTrainedModel, device: str, tokens_or_cache: BatchEncoding | DynamicCache
) -> Iterable[LegacyCache]:
"""Prefills and returns Ks and Vs as a LegacyCache.

Args:
model: The HuggingFace model used for prefill.
device: Target device string (e.g. ``"cuda"``, ``"cpu"``).
tokens_or_cache: Either a ``BatchEncoding`` to prefill, or an existing
``DynamicCache`` to convert directly.

Returns:
Legacy KV cache representation as a tuple of per-layer (K, V) tensor pairs.
"""
if type(tokens_or_cache) is DynamicCache:
return tokens_or_cache.to_legacy_cache() # type: ignore
else:
tokens = tokens_or_cache
dc = DynamicCache()
with torch.no_grad():
dc = model(
tokens["input_ids"].to(device), # type: ignore
attention_mask=tokens["attention_mask"].to(device), # type: ignore
past_key_values=dc,
).past_key_values
return dc.to_legacy_cache()
print(result)
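As a quick sanity check on the merging logic introduced above, here is a self-contained toy example, assuming a transformers version where `DynamicCache` exposes per-layer `.layers` (as this PR's v5-targeted code does); the tensor shapes are illustrative and not tied to any real model.

```python
# Sketch only: two DynamicCaches with random per-layer KV tensors should
# concatenate along the sequence (time) axis after merging.
import torch
from transformers.cache_utils import DynamicCache

from mellea.backends.kv_block_helpers import merge_dynamic_caches_v5


def toy_cache(seq_len: int, n_layers: int = 2) -> DynamicCache:
    cache = DynamicCache()
    for layer_idx in range(n_layers):
        k = torch.randn(1, 2, seq_len, 4)  # (batch, heads, seq, head_dim)
        v = torch.randn(1, 2, seq_len, 4)
        cache.update(k, v, layer_idx=layer_idx)
    return cache


a, b = toy_cache(5), toy_cache(3)
merged = merge_dynamic_caches_v5([a, b])
assert merged.get_seq_length() == 8  # 5 + 3 tokens along the time axis
```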
3 changes: 0 additions & 3 deletions mellea/backends/openai.py
@@ -56,9 +56,6 @@
convert_tools_to_json,
)

if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer

openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"

format: None = None # typing this variable in order to shadow the global format function and ensure mypy checks for errors