Conversation
* Replace control-token KV hiding with token-exchange by default (#8) Adapter control tokens were padded into every Q/K/V in every decoder layer via `control_dims=32` and masked with `finfo.min`. This bloated the KV cache head_dim by 25-50% and forced FlashAttention onto padded 160/192-wide vectors when only `num_hiding_groups` (typically 1) of the 32 extra dims were ever non-zero. Switch to token-exchange: after the switch reads `input_ids` and detects which adapter to activate, replace each control token's embedding with a substitute real-token embedding before the decoder runs. Control tokens become ordinary content tokens in the residual stream and `control_dims` collapses to 0, dropping the expansion entirely. Substitute ids are computed at compose time: - ALoRA adapters -> first token of alora_invocation_tokens - LoRA/builtin adapters -> tokenizer.bos_token_id New config field `adapter_substitute_token_ids` is persisted in config.json and drives a `use_token_exchange` property read by both backends. Default `control_dims` flips from 32 to 0. The legacy KV-hiding path is preserved as an opt-in escape hatch via the new `--legacy-hiding` composer flag; any adapter that regresses under token-exchange can be composed with the old semantics unchanged. Key validation: - Reject num_adapters>0 with neither hiding nor substitute ids (would leak raw control-token embeddings into attention). - Reject duplicate adapter_token_ids (LUT collision). - Reject negative / wrong-length substitute ids. Position correction via `hidden_count` is skipped in token-exchange mode since control tokens are real positions. Design: docs/KV_CACHE_OVERHEAD_REMOVAL.md Tracks issue #8. * Add token-exchange parity eval harness (#8) Measures four metrics per position, teacher-forced, across a list of prompts to compare legacy KV-hiding (control_dims>0) vs. token-exchange (control_dims=0 + substitute ids): 1. KL(p_old || p_new) per position (log_softmax based to avoid underflow) 2. Top-1 agreement (tagged "(noisy)" on wide nuclei) 3. Nucleus (top-p=0.9) Jaccard (sampling-set overlap) 4. Mass under old nucleus by new (the actionable gate) Results are partitioned into overall / pre-control / adapter-active buckets. The pre-control bucket must be bit-for-bit identical (KL max == 0, top-1 agree == 1.0); any drift there signals a bug in the embedding-swap gating rather than a mode trade-off. Two modes: - Synthetic (CPU): builds two HF models with identical base weights, one in legacy hiding and one in token-exchange. Useful as a plumbing check and regression guard. Runs as a standard pytest. - Real-model (GPU, opt-in): set GRANITE_SWITCH_PARITY_MODELS='{"old":..., "new":...}'. Loads composed checkpoints and uses demo-script prompts (14 adapter-specific prompts from run_adapter_generation_direct.py) rendered through the composed tokenizer's chat template. Thresholds: top-1 >= 0.95, mean KL <= 0.02, mean mass-under-old-nucleus >= 0.88. Also exposes build_demo_prompts() in the demo script. Short-circuits _generate via a module-level capture flag so prompt text is collected without touching model.generate. Used by the parity eval to pull realistic adapter inputs without duplicating the demo prompt data. CLI usage: python -m tests.integration.test_token_exchange_parity \ --old /path/to/legacy_build --new /path/to/te_build --json-out report.json * Use <|start_of_role|> instead of BOS for LoRA/builtin substitute (#8) Granite tokenizers alias bos_token_id to <|end_of_text|> (EOS), so the previous BOS-based substitute for LoRA/builtin adapters would have injected an end-of-text signal mid-prompt — a stop-generation marker in a place the model was not trained to see it. The chat template places the LoRA control token at sequence start, immediately followed by <|start_of_role|>user<|end_of_role|>... — so <|start_of_role|> is the deterministic "token that naturally follows" for every LoRA adapter, and its embedding is well-trained in the base model (part of the base vocab on Granite 4.0 and 4.1). Parallels the ALoRA path (substitute = first invocation token). Both paths now pick "the token that comes right after the control token in the rendered chat prompt" — single principle, two sources. Validated: - tokenizer.convert_tokens_to_ids('<|start_of_role|>') == 100264 on ibm-granite/granite-4.1-3b and granite-4.0-micro (part of base vocab, not composer-added). - bos_token_id == eos_token_id == 100257 ('<|end_of_text|>') on all three Granite tokenizers tested — confirming the prior default was semantically wrong. * Template: drop the first role marker after each adapter control token (#8) The runtime embedding swap replaces each adapter control token's embedding with a substitute token's embedding — for LoRA adapters this is <|start_of_role|>, for assistant-boundary ALoRA adapters it's also <|start_of_role|> (the first token of their invocation sequence). But the chat template then emits a *real* <|start_of_role|> at the next position: the user or assistant role marker that naturally follows the control-token prefix. Result before this change: two consecutive positions carrying <|start_of_role|>'s embedding. The model has never seen that pattern during pretraining — a duplicate-embedding OOD right at the start of the decoder's residual stream. Fix: add a skip-once Jinja flag (ns.skip_next_start_of_role). Arm it when lora_prefix_insertion emits the LoRA control token, or when alora_insertion fires the fallback path for assistant-boundary ALoRAs. Wrap every <|start_of_role|> emission in the base Granite template with a skip-once block that consumes the flag. The flag is single-shot — only the very first <|start_of_role|> after the control token is suppressed; all later role markers emit normally. Not addressed in this PR: ALoRAs whose invocation text is in-message content text (<requirements>, <guardian>, <certainty>). The first token of these invocations is the single character '<', and the rest of the invocation text cannot be cleanly sliced at the template level without changing what 'requirements>' (or 'guardian>', etc.) tokenizes to. Those adapters retain the duplicate-embedding pattern until a runtime-level drop lands in a follow-up. Backward compatibility: old checkpoints (composed before this change) load unchanged — the template edit only runs at compose time and affects only newly-composed models. Their rendered output for LoRA and assistant-boundary ALoRA is now one token shorter than before (the suppressed <|start_of_role|>). Update the three test_chat_template tests whose assertions encoded the old contract. * Template: drop first char of in-message ALoRA invocation (#8) Closes the remaining duplicate-embedding OOD at the swap site. Complements the skip-once <|start_of_role|> edit from the previous commit by extending the same principle to ALoRA adapters whose invocation text lives inside a user message (<requirements>, <certainty>, <guardian>, <context>, etc.). Change: in alora_pass2, after inserting the control token before the invocation text, also drop the first CHARACTER of the invocation text. Example: "Please <|req_check|><requirements>" becomes "Please <|req_check|>requirements>". At runtime the embedding-swap replaces the control token's embedding with the first invocation token's embedding — the embedding of '<'. The decoder then sees [<|req_check|>→e_<, requirements, >] — exactly what "<requirements>" tokenizes to in isolation, with no duplicate. Why this is safe on the Granite tokenizer: verified empirically via a new property test (test_first_char_drop_equals_first_token_drop). For every ALoRA invocation in the standard library, tokenizing the full invocation and dropping the first token ID yields the same sequence as tokenizing the string with its first character removed. BPE's greedy- merge would break this property if the second-byte merges depended on the leading '<'; it doesn't, because '<' tokenizes as its own single- character token in every case. The accompanying test test_first_token_is_single_character asserts the complementary invariant: the first token of each invocation decodes to exactly one character. If a future invocation text starts with a multi-character first token, that test catches it — the Jinja edit (invocation_text[1:] drops one character) would otherwise silently produce a wrong-length drop. Combined with the previous commit (skip-once <|start_of_role|>), the duplicate-embedding pattern is now eliminated across all adapter types in the Granite adapter library: LoRA, assistant-boundary ALoRA, and in-message ALoRA. * Derive LoRA substitute from the tokenizer's chat template (#8) Previously the composer hardcoded _LORA_SUBSTITUTE_TOKEN = "<|start_of_role|>". That's the right answer for Granite 4.x but it ties the default-path composer to a Granite-specific token name. Any base model with a different chat template (different role marker, different turn-open convention) would silently get the wrong substitute — a token the base model knows, but not the one sitting at position 1 of its rendered prompt. Replace the hardcode with a compose-time probe: render a minimal no-adapter user turn through tokenizer.apply_chat_template, tokenize, and read input_ids[0]. That's by construction whatever the template emits at the start of a normal turn, which is exactly what sits at position 1 after a LoRA-prepended control token. The substitute and the template's own behavior are now derived from the same source of truth. Verified: the probe returns 100264 (<|start_of_role|>) on granite-4.1-3b, granite-4.0-micro, and granite-switch-4.1-3b-preview — identical to the previous hardcoded value. Behavior on Granite is unchanged; the door is open for non-Granite base models. Error paths give actionable messages: - Tokenizer has no chat_template → suggest --legacy-hiding - Template render fails → report the Jinja error, suggest --legacy-hiding - First token is <unk> → report that the template emits something outside the vocab - Probe returns an empty id list → same Tests: - tests/composer/test_lora_substitute_probe.py (7 cases): * Real tokenizer round-trip on granite-4.1-3b and 4.0-micro * Synthetic tokenizer with a non-Granite template returns the custom template's first-token id * All four error paths raise ValueError with matching messages * Move token-exchange rewrite into the switch (#8) Refactor: the runtime substitution LUT and the embedding-swap step move out of each backend's decoder and into SingleSwitch (HF + vLLM). The switch now performs both halves of token-exchange: 1. Adapter selection — read input_ids, detect control tokens via input_ids == adapter_token_ids, emit per-token adapter_indices (unchanged). 2. Token rewrite — replace each control token's id in input_ids with its substitute id (from a switch-owned LUT). New. The switch's forward signature changes from -> adapter_indices to -> (adapter_indices, modified_input_ids) The decoder consumes both: adapter_indices feeds the LoRA layers as before, modified_input_ids feeds embed_tokens / get_input_embeddings exactly once. There is no longer a decoder-side LUT, no scatter, no clone-guard, no use_token_exchange branch in the embedding path. Why this is cleaner: - Single source of truth for the substitution. The switch already knows which positions are control tokens; rewriting input_ids at those positions is a natural extension of "decide which adapter is active." The decoder is genuinely token-exchange-agnostic — it just embeds whatever input_ids it receives. - HF and vLLM converge to the same control flow. Both backends now call switch(...), unpack two outputs, embed once. Previously each backend had a near-identical but layout-specific (B,S,H vs T,H) embedding-swap block + clone-guard that needed to be maintained separately. - Smaller diff for any future change to the substitution logic. Whether to ship a different substitute strategy (e.g. learned embedding, per-adapter rules) becomes a one-place change in the switch instead of a two-place change across both decoders. HF model forward also reorders slightly: switch runs before embed_tokens, so we embed exactly once on modified_input_ids. create_causal_mask now receives a stub embedding tensor of the right shape and dtype (it only uses the tensor for batch/query/dtype inference per the upstream docstring), since the real embedding hasn't been computed yet. Tests: - tests/hf/test_single_switch.py: _run helper unpacks the new tuple return; TestBatchProcessing similarly. - tests/hf/test_token_exchange.py: LUT presence assertion now reads model.model.switch.control_to_substitute_lut instead of model.model.control_to_substitute_lut. No behavior change verified by 756 passing tests (= same count as before the refactor; +0 -0 after fixture updates). * Remove the legacy KV-hiding code path (#8) Token-exchange has been the default for several commits. This change deletes the dead-but-still-callable KV-hiding code path entirely: Config: - Drop control_dims, hiding_groups, hiding_policy, adapter_third_party parameters and the corresponding state. - Drop expanded_head_dim, num_hiding_groups, hiding_group_names, use_token_exchange properties (token-exchange is now always on when num_adapters > 0). - Drop get_hiding_group_token_ids, get_third_party_adapter_mask, get_adapter_hiding_policy_matrix methods. - adapter_substitute_token_ids becomes required when num_adapters > 0. - Net: -150 LoC (config.py 345 → 195). Models: - HF and vLLM both drop token_to_group_mask / adapter_hiding_matrix buffers, hidden_count / adjusted_position_ids logic, and the token_group_membership / query_group_suppression plumbing through decoder layers. - The HF decoder layer's forward signature drops two kwargs. Attention layers (hf/core/lora.py, vllm/core/decoder.py): - Drop expand_control_dims / control_dims / expanded_head_dim fields. - Delete _expand_with_control_dimensions method entirely (~85 LoC each). - Delete the expansion / trim-back branches in forward. - vllm/core/decoder.py: attn_head_dim is unconditionally head_dim. Switches: - Drop config.expanded_head_dim references; head_dim is config.projection_head_dim everywhere. vllm/__init__.py: - ModelArchConfigConvertor.get_head_size() returns config.projection_head_dim (no expansion logic). Composer: - compose_granite_switch.py: drop --control-dims and --legacy-hiding CLI flags. Delete the legacy-hiding branch in build(); always token-exchange. - compose_utils.py: drop hiding_groups / hiding_policy / adapter_third_party kwargs. - model_card.py: drop control_dims / legacy_hiding / use_token_exchange reporting fields. Tests deleted entirely: - tests/unit/test_hiding_constant.py - tests/hf/test_kv_hiding_gap_equivalence.py - tests/vllm/test_kv_hiding_gap_equivalence.py - tests/vllm/_kv_hiding_gap_tests.py - tests/hf/test_position_zero_nan.py - tests/vllm/_position_zero_nan_tests.py - tests/integration/test_token_exchange_parity.py (compared old vs new modes; with no old mode, nothing to compare). - tests/composer/test_built_in_adapters.py (entire file tested removed Mode A / Mode B distinction). Tests rewritten: - tests/conftest.py, tests/unit/test_config{,_edge_cases}.py, tests/unit/test_token_exchange.py, tests/hf/test_model_forward.py, tests/hf/test_token_exchange.py, tests/hf/test_qk_norm.py, tests/shared/granite4_equivalence.py, tests/shared/generation_models.py: fixtures and assertions updated for the simpler config surface. Net diff: ~3000 LoC deleted, ~200 LoC added (test rewrites). 643 tests pass on CPU after the refactor (was 756; the difference is parameterized hiding-equivalence tests + the parity harness, all deleted). Breaking change for any externally-composed checkpoint that was using control_dims > 0: those checkpoints are unloadable under this version. The token-exchange path has been the documented default since #8 and the only path that received the chat-template drops, so any in-flight build should already be on it. * Fix base-weight validator rejecting control_to_substitute_lut buffer (#8) The new switch buffer was failing compose-pipeline validation because buffer_keywords still listed the deleted legacy buffer names instead of the new one. Replace token_to_group_mask / adapter_hiding_matrix / all_hiding_group_token_ids with control_to_substitute_lut in arch.py and in the two test_granite4_mini parameter-allowlist assertions. * Remove dead hiding-constant report from compose output (#8) The report described safety margins for the finfo.min K-side hiding constant. Hiding is gone, so the section is meaningless. Drop the module, the call site in compose_report.py, and the package re-exports. * Update tests for removed legacy hiding fields (#8) — partial Replace control_dims / hiding_groups / hiding_policy / adapter_third_party references with adapter_substitute_token_ids in test fixtures, and drop TestControlTokenKVInvisibility (tested the deleted hiding mechanism). This is a partial sweep — vLLM workers, hf/test_single_switch_e2e.py, and shared/granite4_equivalence.py still need follow-up edits. * Strip remaining legacy hiding-field references from tests/docs (#8) - tests/hf/test_single_switch_e2e.py: drop CONTROL_DIMS_MODES axis; one parametrization on attention_multiplier only. Fixture returns a 3-tuple. - tests/vllm/_generation_equivalence_worker.py and _tp_integration_worker.py: remove control_dims/hiding_groups/hiding_policy/adapter_third_party from composer calls; pass adapter_substitute_token_ids instead. - tests/vllm/_single_switch_worker.py: mock_config uses projection_head_dim. - tests/vllm/test_generation_equivalence.py: docstring updated. - tests/shared/granite4_equivalence.py: rationale comments updated for token-exchange (no behavior change). - src/granite_switch/composer/compose_utils.py: docstring/comment cleanup. * Drop tensor.any() gate from switch token-exchange rewrite (#8) The vLLM decoder is wrapped in @support_torch_compile; Dynamo cannot trace data-dependent branching like ``if is_control.any()``. The gate broke engine init on GPU runs. Replace it with an unconditional torch.where in both backends — keeps HF and vLLM symmetric, costs one indexed gather + one elementwise select per forward, and makes the switch compile-safe. * Fix vLLM test runners after switch tuple-return + dead-class purge (#8) Three fixes uncovered by GPU run: 1. tests/vllm/_single_switch_worker.py: switch.forward now returns (adapter_indices, modified_input_ids); unpack and return only the indices. Worker was calling .cpu() on a tuple → every parametrized test in tests/vllm/test_single_switch.py failed at the same point. 2. tests/vllm/test_model_forward.py: drop the TestControlTokenKVInvisibility class stub. The inner class was deleted with the legacy hiding tests in 0ddaf0e, but the parametrized runner still referenced it. 3. tests/vllm/test_position_zero_nan.py: deleted. The inner _position_zero_nan_tests.py was removed (only existed for the legacy hiding path); the runner became orphan and pytest reported "file or directory not found" on every parametrized variant. The flash_api.cpp:697 "no kernel image" failures in test_model_forward are pre-existing GPU/FlashAttention environment issues, not branch bugs. * Document the LoRA substitute probe's data-independence assumption (#8) Per review feedback: tighten the probe's docstring to state the assumption it relies on — that the chat template emits a constant input_ids[0] regardless of message content, system-prompt presence, or generation-prompt flag — and call out that this is verified empirically for Granite 4.x (every realistic render shape produces <|start_of_role|>). Note what would need to change if a future base model's template breaks the assumption. No behavior change. * Remove residual position-correction references (#8) The position-correction code path was removed when token-exchange became the default. Drop the dead test class and the two stale parenthetical comments that still mentioned it. No behavioural change. * Add tests/vllm/test_token_exchange.py covering vLLM token-exchange path Mirrors tests/hf/test_token_exchange.py for the vLLM SingleSwitch backend. Closes the coverage gap raised by @antonpibm on PR #34: the token-exchange LUT and modified_input_ids rewrite were tested only on HF — vLLM had zero direct assertions on either. New file `tests/vllm/test_token_exchange.py`: - TestLUTMapping: query_lut command returns control_to_substitute_lut; asserts lut[ctrl_id] == sub_id for each adapter, lut[other] == -1 - TestInputRewrite: forward_with_modified command returns both adapter_indices and modified_input_ids; asserts non-control positions unchanged, control positions rewritten to substitute, multi-control sequence handles each independently, and adapter detection still fires on the original (pre-rewrite) input_ids Worker changes (`tests/vllm/_single_switch_worker.py`): - Mock config now populates adapter_token_ids + adapter_substitute_token_ids so SingleSwitch builds the LUT — production configs always have these. Existing tests are unaffected (they discard modified_input_ids). - New _run_with_modified() helper returns both forward outputs as lists - New "forward_with_modified" and "query_lut" commands wired into the request loop. The pre-existing "forward" command is unchanged. Substitute mapping in worker: control id (1000+i) → substitute id (i+1), matched by ADAPTER_SUBSTITUTE_TOKEN_IDS_LIST in the new test file. * Move LUT buffer to CUDA in worker setup Previous commit added adapter_substitute_token_ids to the worker's mock config to enable token-exchange tests. SingleSwitch.__init__ registers control_to_substitute_lut as a CPU buffer (no device specified), and the worker never explicitly calls switch.to(device) — the Q/K/V tensors are built on CUDA directly in forward(). On Ampere this caused the very first forward to raise: RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) …inside the new lut[input_ids] indexing operation. The fail-fast probe caught it cleanly and broke every existing test that uses the worker. Fix: after switch construction, move control_to_substitute_lut to the same CUDA device the worker uses for everything else. Test-only fix; production code path (SingleSwitch.__init__) is untouched.
* Update query rewrite example in hello_mellea tutorial Replace an example. * Fix broken Colab links after tutorial renaming Update notebook links to use new names without numeric prefixes: - 01_hello_mellea.ipynb → hello_mellea.ipynb - 03_01_govt_rag_pipeline_simple.ipynb → rag_full_pipeline.ipynb - 04_compose_granite_switch.ipynb → compose_granite_switch.ipynb - 05_alora_vs_lora_race.ipynb → alora_vs_lora_race.ipynb * Update GPU requirement from A100 to T4 in hello_mellea The 3B model runs fine on T4 GPUs, which are more accessible on Colab. Tested and verified on T4. * Rename files and update terminology across tutorials - Rename rag_full_pipeline.ipynb to rag_full_flow.ipynb - Rename bring_your_own_adapter.md to build_your_own_adapter.md - Rename mellea_bring_your_own_adapter.md to mellea_build_your_own_adapter.md - Rename run_pipeline to run_conversation_turn in rag_full_flow - Replace "adapter" with "adapter function" in user-facing text where it refers to the invocable capability (not LoRA weights or file names) - Update all cross-references to match new file names --------- Co-authored-by: Alon Freund <alonf@il.ibm.com> Co-authored-by: yairallouche <yair@il.ibm.com>
…otebooks (#64) - Updated hello_adapter.ipynb to specify T4 or better GPU requirement - Updated granite_switch_with_hf.ipynb to specify T4 or better GPU requirement - Consistent with hello_mellea.ipynb which was previously tested on T4 - Makes tutorials more accessible by clarifying minimum GPU requirements Co-authored-by: Alon Freund <alonf@il.ibm.com>
* Convert "pipeline" to "flow", rename rag_full_flow to rag_flow, drop unused display variants * Validate links * Align tutorial notebooks with template
Change mellea dependency from range constraint (>=0.1.0,<=0.6.0) to exact version pinning (==0.6.0) Co-authored-by: Alon Freund <alonf@il.ibm.com>
Updates requires-python from >=3.10,<3.14 to >=3.11,<3.14 to match the actual requirement of the mellea==0.6.0 dependency, which requires Python >=3.11. This prevents unsatisfiable dependency errors for Python 3.10 users and provides a clear error message upfront. Fixes #69 Co-authored-by: Alon Freund <alonf@il.ibm.com>
BitsAndBytes 4-bit quantization packs weights as uint8 with shape [total_elements//2, 1], which breaks the existing weight.shape-based dimension detection in SwitchedLoRALinear.__init__(). Fix: - Prefer input_size_per_partition / output_size_per_partition attributes (always correct, regardless of weight packing format) - Fall back to weight.shape only for non-parallel layers - Add dtype guard: if weight dtype is non-floating-point (uint8 for BnB), default to bfloat16 for LoRA buffer allocation Also adds vLLM quantization tests (BnB INT4 + FP8) that verify: - Base model weights are actually quantized - LoRA/aLoRA weights remain in full precision - Adapters activate correctly under quantization - LoRA dimensions are not corrupted by packed weight shapes Closes #16
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
BitsAndBytes 4-bit quantization packs weights as uint8 with shape [total_elements//2, 1], which breaks the existing weight.shape-based dimension detection in SwitchedLoRALinear.init().
Fix:
Also adds vLLM quantization tests (BnB INT4 + FP8) that verify: