Declarative weight conversion framework #523

Open
jlamypoirier wants to merge 10 commits into main from jlp_declarative_weight_converters

jlamypoirier (Collaborator) commented May 20, 2026

Summary

Mirrors the post-#508 declarative config-side shape on the weight side. Every section converter now declares its weight mapping via _create_weight_converters (cached, override-by-key), the walker flattens into the existing list[WeightConverter] runtime, and tied embeddings move from per-call drop_on_export=tied plumbing into the walker-central OutputProjectionWeightConverter marker. Item 1 of the post-#508 deferred list.

What changed

  • Framework in external.py — new primitives NestedWeightConverter, BlockSequenceWeightConverter (Fixed/Pattern + optional per-position dispatch_registry), DispatchWeightConverter (single-attribute type dispatch), TypedDictWeightConverter (per-key dispatch on dict[str, Config]), LinearWeightConverter (bundles .weight/.bias with per-section bias_fn), OutputProjectionWeightConverter (walker-dropped when root_config.tied_embedding_weight is set). Generic transforms KeyValueWeightConverter / TransposeSplitWeightConverter (was MLPLayer2Converter) / PatchEmbeddingWeightConverter relocate here from llama.py / llava.py.
  • All formats migrated — llama, mistral, qwen2, mixtral, mtp_llama, gemma4, diffusion_dream, diffusion_llama, apriel, apriel2, llava, apriel2-multimodal.
  • Cleanup: get_parameter_converter, get_weight_and_bias_converters, the MLPLayer2Converter alias, and the drop_on_export parameter plumbing are removed. effective_bias stays as a published helper for Apriel/Apriel2 CustomConfigConverter export_fns.
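
The shape described in the list above can be sketched roughly as follows. This is a minimal, self-contained illustration and not the code in external.py: the classes `Linear`, `OutputProjection`, `Nested`, the `walk` function, the `emit` signature, and the `add_linear_biases` field are hypothetical stand-ins. Only the ideas come from the description: each section declares its mapping by key, a walker flattens it into a flat list, and the output projection is dropped when `tied_embedding_weight` is set on the root config.

```python
from types import SimpleNamespace


def join(prefix, name):
    # Assumed rule: an empty prefix contributes nothing to the dotted path.
    return f"{prefix}.{name}" if prefix else name


class Linear:
    """Sketch of a primitive bundling .weight with an optional .bias."""

    def __init__(self, hf_name, bias_fn):
        self.hf_name, self.bias_fn = hf_name, bias_fn

    def emit(self, key, config, root_config, fl_prefix, hf_prefix):
        fl, hf = join(fl_prefix, key), join(hf_prefix, self.hf_name)
        suffixes = ["weight"] + (["bias"] if self.bias_fn(config) else [])
        return [(f"{fl}.{s}", f"{hf}.{s}") for s in suffixes]


class OutputProjection(Linear):
    """Sketch of the marker that emits nothing when embeddings are tied."""

    def emit(self, key, config, root_config, fl_prefix, hf_prefix):
        if root_config.tied_embedding_weight:
            return []
        return super().emit(key, config, root_config, fl_prefix, hf_prefix)


class Nested:
    """Sketch of descent into a sub-section that has its own declaration."""

    def __init__(self, hf_name, section):
        self.hf_name, self.section = hf_name, section

    def emit(self, key, config, root_config, fl_prefix, hf_prefix):
        return walk(self.section, getattr(config, key), root_config,
                    join(fl_prefix, key), join(hf_prefix, self.hf_name))


def walk(section, config, root_config, fl_prefix="", hf_prefix=""):
    """Flatten a section's declared mapping into a flat list of name pairs."""
    pairs = []
    for key, primitive in section.create_weight_converters().items():
        pairs += primitive.emit(key, config, root_config, fl_prefix, hf_prefix)
    return pairs


class AttentionSection:
    @staticmethod
    def create_weight_converters():
        return {"query": Linear("q_proj", lambda c: c.add_linear_biases)}


class ModelSection:
    @staticmethod
    def create_weight_converters():
        return {
            "attention": Nested("self_attn", AttentionSection),
            "output": OutputProjection("lm_head", lambda c: False),
        }


config = SimpleNamespace(
    tied_embedding_weight=True,
    attention=SimpleNamespace(add_linear_biases=True),
)
pairs = walk(ModelSection, config, config)
```

With `tied_embedding_weight=True` the `output` entry contributes nothing, so no call site needs its own tied-embedding flag. Because each declaration is a plain dict, a subclass could replace a single entry by key without restating the rest.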

Test plan

Follow-up

  • Extending the static converter test walker to weight coverage (every state-dict key emitted by a default-constructed model is consumed by some leaf WeightConverter) is worth doing but not on the critical path — deferred.
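
As a rough idea of what the deferred weight-coverage check might look like, here is a sketch under stated assumptions: `unconsumed_keys` and its inputs are hypothetical, and the real walker would obtain the key list from a default-constructed model and the claimed names from the emitted leaf converters.

```python
def unconsumed_keys(state_dict_keys, leaf_converter_names):
    """Return state-dict keys that no leaf converter claims, sorted."""
    claimed = {name for names in leaf_converter_names for name in names}
    return sorted(set(state_dict_keys) - claimed)


# Hypothetical inputs: each inner tuple lists the names one leaf consumes.
state_dict_keys = ["embed.weight", "blocks.0.mlp.weight", "blocks.0.mlp.bias"]
leaves = [("embed.weight",), ("blocks.0.mlp.weight",)]
missing = unconsumed_keys(state_dict_keys, leaves)
```

A test built on this would assert that `missing` is empty for every format.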

🤖 Generated with Claude Code

jlamypoirier and others added 4 commits May 20, 2026 12:52
Mirrors the post-#508 config-side shape on the weight side. Adds
``_create_weight_converters`` + walker on ``ConfigSectionConverter`` with new
primitives (Nested/BlockSequence/Linear/OutputProjection) in ``external.py``.
Relocates ``KeyValueWeightConverter``/``TransposeSplitWeightConverter``
(formerly ``MLPLayer2Converter``) so the layers/multimodal converters can
import them from the engine instead of llama.py.

Migrates llama/mistral/qwen2/mixtral/mtp_llama to the new shape. Tied
embeddings move from per-call ``drop_on_export=tied`` plumbing to the
walker-central ``OutputProjectionWeightConverter`` marker. Legacy
``get_converters``/``get_parameter_converter``/``get_weight_and_bias_converters``
helpers stay in llama.py as shims for the not-yet-migrated converters
(apriel/apriel2/gemma4/multimodal); cleanup commit removes them.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Adds ``DispatchWeightConverter`` (runtime-type dispatch on a single attribute) and
``TypedDictWeightConverter`` (per-key dispatch on a ``dict[str, Config]`` attribute) to the
framework. Apriel's hybrid block-sequence uses ``BlockSequenceWeightConverter``'s
``dispatch_registry``; Apriel2 uses both new primitives — ``DispatchWeightConverter`` for the
block mixer + normalization dispatch, ``TypedDictWeightConverter`` for the StochasticMixer
sub-mixer fan-out.
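
A minimal sketch of the two dispatch styles named above, assuming an exact-type registry lookup. The config classes, emit functions, registry, and prefixes here are hypothetical and do not reflect the framework's real signatures.

```python
class AttentionConfig:
    pass


class MambaConfig:
    pass


def emit_attention(config, fl_prefix, hf_prefix):
    return [(f"{fl_prefix}.query.weight", f"{hf_prefix}.q_proj.weight")]


def emit_mamba(config, fl_prefix, hf_prefix):
    return [(f"{fl_prefix}.in_proj.weight", f"{hf_prefix}.in_proj.weight")]


# Hypothetical registry: config type -> function emitting that type's weights.
REGISTRY = {AttentionConfig: emit_attention, MambaConfig: emit_mamba}


def dispatch(config, fl_prefix, hf_prefix):
    """Single-attribute dispatch: choose the emitter from the runtime type."""
    return REGISTRY[type(config)](config, fl_prefix, hf_prefix)


def typed_dict_dispatch(configs, fl_prefix, hf_prefix):
    """Per-key dispatch over a dict[str, Config]; each key extends both prefixes."""
    pairs = []
    for name, sub_config in configs.items():
        pairs += dispatch(sub_config, f"{fl_prefix}.{name}", f"{hf_prefix}.{name}")
    return pairs


mixers = {"attn": AttentionConfig(), "mamba": MambaConfig()}
single = dispatch(MambaConfig(), "mixer", "mixer")
fan_out = typed_dict_dispatch(mixers, "mixer.mixers", "mixer.mixers")
```

The lookup uses `type(config)` directly, so a subclass of a registered config would raise `KeyError` in this sketch. Whether the real primitives match subclasses is not stated in the commit message.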

The Apriel2 Fixed/Pattern decoder section converters now contribute no weights of their own;
the block fan-out runs once at the base-model level via ``BlockSequenceWeightConverter``, which
already handles both shapes through its ``FixedBlockSequenceConfig`` / ``PatternBlockSequenceConfig``
dispatch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
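
The Fixed/Pattern handling mentioned in the commit above could look roughly like the following. This is a sketch only: `FixedSketch` and `PatternSketch` are hypothetical stand-ins for the two config classes, and the field names and the cyclic repetition of the pattern are assumptions.

```python
from dataclasses import dataclass


@dataclass
class FixedSketch:
    block: str          # one block config repeated for every position
    num_blocks: int


@dataclass
class PatternSketch:
    blocks: dict        # name -> block config
    pattern: list       # names, assumed to repeat cyclically
    num_blocks: int


def expand_block_sequence(sequence):
    """Materialize one block config per position for either shape."""
    if isinstance(sequence, FixedSketch):
        return [sequence.block] * sequence.num_blocks
    if isinstance(sequence, PatternSketch):
        period = len(sequence.pattern)
        return [sequence.blocks[sequence.pattern[i % period]]
                for i in range(sequence.num_blocks)]
    raise TypeError(f"unsupported block sequence: {type(sequence).__name__}")


fixed = expand_block_sequence(FixedSketch("attn", 3))
hybrid = expand_block_sequence(
    PatternSketch({"a": "attn", "m": "mamba"}, ["a", "m"], 5)
)
# Each position then gets its own indexed prefix for the per-block fan-out.
prefixed = [(f"decoder.{i}", block) for i, block in enumerate(hybrid)]
```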
Last of the model-side migrations: the multimodal Llava (Pixtral vision + Mistral text) and the
multimodal Apriel2 (Apriel2 vision + Apriel2 text base) handlers now declare their weight
conversions via ``_create_weight_converters``. ``PatchEmbeddingWeightConverter`` is now imported
from the engine (relocated earlier); the local copies in ``llava.py`` are removed.

Gemma4 keeps its imperative ``get_converters`` and continues to work via the
``ConfigSectionConverter.get_converters`` shim — its helper classes don't inherit
``ConfigSectionConverter`` so they don't get a free declarative migration. Revisit in cleanup or
a follow-up.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Removes the post-migration deadweight:

* ``ConfigSectionConverter.get_converters`` (section-shape shim) — every consumer now calls
  ``emit_weight_converters`` directly.
* ``get_parameter_converter``, ``get_weight_and_bias_converters``, ``MLPLayer2Converter`` alias in
  llama.py — no remaining callers.
* ``drop_on_export`` parameter plumbing throughout gemma4 — the only legitimate use case
  (head tied embeddings) is handled by ``OutputProjectionWeightConverter`` at the walker level.

Gemma4 gains a local ``_linear_converters`` helper that builds ``.weight`` and (optional)
``.bias`` ``WeightConverter`` instances directly — Gemma4's helper classes don't inherit
``ConfigSectionConverter`` so the ``LinearWeightConverter`` declarative primitive doesn't apply.

``effective_bias`` stays in llama.py as a published helper — still used by Apriel/Apriel2
config-side ``CustomConfigConverter`` export_fns and the matching ``LinearWeightConverter.bias_fn``
lambdas.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
jlamypoirier (Collaborator, Author) commented May 20, 2026

CI status [fixed as of #524]

CI shows 60 failures, but these are pre-existing and identical to main's. Confirmed by diffing the failed-test names: cmp /tmp/main_failed /tmp/pr_failed → identical.
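
For reference, an order-insensitive version of the same comparison can be sketched in Python. `new_failures` is a hypothetical helper, and the inputs stand in for the contents of the two failed-test files.

```python
def new_failures(main_failed_text, pr_failed_text):
    """Return test names that fail on the PR but not on main, sorted."""
    def names(text):
        return {line.strip() for line in text.splitlines() if line.strip()}
    return sorted(names(pr_failed_text) - names(main_failed_text))


# Hypothetical file contents, one failed test name per line.
main_failed = "tests/a.py::test_x\ntests/b.py::test_y\n"
pr_failed = "tests/b.py::test_y\ntests/a.py::test_x\n"
regressions = new_failures(main_failed, pr_failed)
```

Unlike `cmp`, this ignores ordering and reports only names introduced by the PR.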

All failures are TypeError: create_causal_mask() got an unexpected keyword argument 'cache_position' from fast_llm_external_models/apriel2/modeling_apriel2.py:940 — a transformers-version mismatch introduced by #521 (PyTorch base image bump to 26.02-py3 brought a newer transformers where create_causal_mask dropped the cache_position kwarg). None of these files are touched by this PR.

Local + cluster results on the PR's HEAD:

jlamypoirier and others added 6 commits May 21, 2026 11:08
* Drop ``AprielBlockConverter.get_converters`` — calls a now-nonexistent ``.get_converters`` on
  the block-converter registry values and is unreachable in practice (dispatch goes through
  ``BlockSequenceWeightConverter(dispatch_registry=...)``).
* Drop the unused ``block_converter_class`` ClassVar from Apriel/Mistral/Qwen2/Mixtral head
  converters — only MTP-Llama's head reads it (kept on ``LlamaHeadConverter``).
* Drop the ``exported_config`` parameter throughout: no surviving ``get_converters`` override
  reads it, and the ``__init__`` ``_export_config(model.config)`` precompute it powered is
  gone. Tied-embedding handling lives on ``OutputProjectionWeightConverter``.
* Fold ``_FixedBlockFanoutWeightConverter`` into ``BlockSequenceWeightConverter`` via a
  ``config_attr=""`` sentinel for "section IS the block sequence" — kills the cross-package
  private import from ``llama.py`` into ``multimodal/apriel2.py``.
* ``LinearWeightConverter.bias_fn`` and ``OutputProjectionWeightConverter._emit`` use direct
  attribute access instead of ``getattr(..., default)`` — misuse now surfaces as ``AttributeError``
  rather than silently falling back to ``False``.
* Tighten ``BlockSequenceWeightConverter``'s assertion to XOR — passing both
  ``block_converter_class`` and ``dispatch_registry`` no longer silently ignores the former.
* Extract ``_join_prefix(parent, own)`` helper for the empty-handling rule shared across
  Nested/BlockSequence/Dispatch/TypedDict ``_emit`` methods.
* Apriel2 base + multimodal aggregators get a ``block_converter_class`` ClassVar (matches
  ``LlamaBaseModelConverter``) instead of hardcoding ``Apriel2BlockConverter`` inline.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
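
Two of the items in the commit above lend themselves to a short sketch. The empty-handling rule in `join_prefix` is an assumption about what the helper does (skip whichever side is empty), and the `add_linear_biases` / `add_bias` names are hypothetical.

```python
from types import SimpleNamespace


def join_prefix(parent: str, own: str) -> str:
    """Join two dotted prefixes, skipping whichever side is empty (assumed rule)."""
    if not parent:
        return own
    if not own:
        return parent
    return f"{parent}.{own}"


# Contrast between the two access styles on a misspelled field name.
config = SimpleNamespace(add_linear_biases=True)

# getattr with a default hides the typo and silently yields False.
lenient = getattr(config, "add_bias", False)

# Direct attribute access surfaces the same typo as an AttributeError.
try:
    strict = config.add_bias
except AttributeError:
    strict = "AttributeError"
```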
Items 1-7 from the second /review-coarse pass:

1. Delete unused IgnoreImportWeightConverter / IgnoreExportWeightConverter
   (last callers were the removed drop_on_export plumbing).
2. Migrate Gemma4 to declarative weight converters — Gemma4Attention/MLP/
   MoEMLP/HybridMoEMLP/Block/Decoder now inherit ConfigSectionConverter
   and define _create_weight_converters. The Gemma4-specific transforms
   (shared-K/V branching, mlp-type dispatch with divergent HF prefixes,
   conditional norm_2, two-level hybrid-MoE norm descent) live as small
   private WeightConverter subclasses next to the existing MoE layer
   converters. Config side stays imperative under CustomConfigConverter
   at the aggregator (Gemma4 sliding/full block divergence prevents a
   uniform per-block declarative shape); each helper carries a blanket
   IgnoredConfigConverter to silence the static walker.
3. Add optional=True to NestedWeightConverter and fold Apriel2 multimodal's
   vision_encoder back into _create_weight_converters (skip when None).
4. Fold Llava head into LlavaBaseModelConverter._create_weight_converters
   (NestedWeightConverter with empty hf_prefix; LlavaHead's leaf names are
   already absolute).
5. Move block_converter_class ClassVar from LlamaHeadConverter to its sole
   reader MTPLlamaHeadConverter.
6. Replace BlockSequenceWeightConverter's config_attr="" sentinel with an
   explicit read_self=True flag (2 callers updated).
7. Delete the four pass-only HeadConverter subclasses (Mistral, Mixtral,
   Qwen2, Apriel); the head_converter_class ClassVar inherits from
   LlamaBaseModelConverter, and LlavaHeadConverter rebases on
   LlamaHeadConverter directly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
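
Item 3 in the commit above (skip a nested section when it is None) might be sketched like this. `emit_nested`, `emit_vision`, and the weight names are hypothetical, and raising `ValueError` for a missing non-optional section is an assumption rather than documented behavior.

```python
from types import SimpleNamespace


def emit_nested(config, attr, emit_section, optional=False):
    """Descend into config.<attr>; with optional=True a None section emits nothing."""
    section = getattr(config, attr)
    if section is None:
        if optional:
            return []
        raise ValueError(f"{attr!r} is None but was not declared optional")
    return emit_section(section, attr)


def emit_vision(section, prefix):
    # Hypothetical leaf mapping for a vision encoder section.
    return [(f"{prefix}.patch.weight", f"{prefix}.patch_conv.weight")]


text_only = SimpleNamespace(vision_encoder=None)
multimodal = SimpleNamespace(vision_encoder=SimpleNamespace())

no_vision = emit_nested(text_only, "vision_encoder", emit_vision, optional=True)
with_vision = emit_nested(multimodal, "vision_encoder", emit_vision, optional=True)
```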
Items 1-4 from the third /review-coarse pass:

1. Split ``BlockSequenceWeightConverter`` into three flat primitives — single-class
   (``BlockSequenceWeightConverter``), per-position dispatch
   (``DispatchBlockSequenceWeightConverter``), and section-IS-the-block-sequence
   (``SelfBlockSequenceWeightConverter``). Drops the dual XOR ``Assert.custom``
   switchboard. Shared list-materialization extracted to ``_expand_block_sequence``.

2. Generalize two framework primitives so two of Gemma4's four private one-offs
   fold in:
   * ``DispatchWeightConverter`` gains ``hf_prefix_overrides`` for per-branch HF
     paths (Gemma4's block.mlp dispatch where dense lands under ``mlp.<...>`` and
     hybrid MoE flat-merges into the block root).
   * ``NestedWeightConverter.config_attr`` accepts tuple/dotted paths for chained
     ``getattr`` (Gemma4's hybrid-MoE inner norms via ``("dense", "pre_norm")``).
   Rename ``_join_prefix`` and ``_prepend_prefix`` to drop the underscore — now
   public utilities used by Gemma4's remaining two one-offs.

3. Lift the one-line ``cls.emit_weight_converters(config, "", "")`` passthrough
   into ``HuggingFaceBaseModelConverter.get_converters`` as a concrete default.
   Apriel2 (text), Apriel2 multimodal, and Llava lose their overrides.
   Apriel2BaseModelConverter now multi-inherits ``HuggingFaceBaseModelConverter``
   so it picks up the default. Llama, Gemma4, MTP-Llama keep their overrides —
   they splice ``head_converter_class.get_converters(config)`` separately because
   the head needs the full ``GPTBaseModelConfig`` (MTP-Llama reads
   ``config.decoder.last_block_config`` for per-prediction-head fan-out).

4. ``AprielBlockConverter`` docstring: ``get_converters`` was removed in the
   prior cleanup pass; update the docstring to describe the class as a registry
   holder consumed by ``ListDispatchConfigConverter`` (config side) and
   ``DispatchBlockSequenceWeightConverter`` (weight side).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
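
The tuple/dotted `config_attr` paths from item 2 in the commit above amount to a chained `getattr`. A minimal sketch, where `resolve_config_attr` and the example config are hypothetical:

```python
from functools import reduce
from types import SimpleNamespace


def resolve_config_attr(config, path):
    """Follow a dotted string or a tuple of names with chained getattr."""
    names = path.split(".") if isinstance(path, str) else path
    return reduce(getattr, names, config)


# Hypothetical nested config mirroring the ("dense", "pre_norm") example.
config = SimpleNamespace(dense=SimpleNamespace(pre_norm="rms_norm"))

by_tuple = resolve_config_attr(config, ("dense", "pre_norm"))
by_dotted = resolve_config_attr(config, "dense.pre_norm")
single = resolve_config_attr(config, "dense")
```

A missing intermediate attribute raises `AttributeError` at the failing hop, which is consistent with the strict-access choice made in the earlier cleanup pass.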
Six Gemma4 section converters owned their config side only nominally —
the actual conversion lives on the aggregator's CustomConfigConverter
(fast_llm_recurses=True). Each declared an empty-path IgnoredConfigConverter
blanket-claim solely to silence the static architecture-coverage walker.

Replace the boilerplate with an explicit weight_only ClassVar on
ConfigSectionConverter that short-circuits both _create_config_converters
(empty default) and check_architecture_coverage.
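
A sketch of how such a flag can short-circuit both paths. The class and method names below are simplified, hypothetical stand-ins for `ConfigSectionConverter`, and the coverage check is reduced to a set difference.

```python
from typing import ClassVar


class SectionConverterSketch:
    # When True, the section owns weights only and claims no config fields.
    weight_only: ClassVar[bool] = False

    @classmethod
    def declared_config_converters(cls) -> dict:
        """Config fields this section claims; subclasses override."""
        return {}

    @classmethod
    def create_config_converters(cls) -> dict:
        return {} if cls.weight_only else cls.declared_config_converters()

    @classmethod
    def check_architecture_coverage(cls, config_fields: set) -> set:
        """Return unclaimed config fields; weight-only sections are exempt."""
        if cls.weight_only:
            return set()
        return set(config_fields) - set(cls.create_config_converters())


class AttentionSketch(SectionConverterSketch):
    weight_only = True  # config side handled elsewhere (by the aggregator)


class MLPSketch(SectionConverterSketch):
    @classmethod
    def declared_config_converters(cls) -> dict:
        return {"intermediate_size": "rename"}


unclaimed_attention = AttentionSketch.check_architecture_coverage({"heads", "head_size"})
unclaimed_mlp = MLPSketch.check_architecture_coverage({"intermediate_size", "activation"})
```

The flag states the intent explicitly, where the previous blanket claim only suppressed the walker's complaint.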
- Delete dead helpers: get_apriel2_decoder_converter and two unreferenced
  _get_weight_converters classmethods (the latter also broken — called
  get_converters with a now-removed two-arg signature).

- Inline the now-unconditional ConfigSectionConverter coverage call in
  HuggingfaceStateDictCheckpointHandler._check_hf_coverage; every concrete
  base_model_converter_class is one.

- Drop the _effective_bias import alias in apriel2 (no name conflict).

- Trim docstrings that referenced the previous (removed) implementation:
  OutputProjectionWeightConverter / LinearWeightConverter / LlavaHeadConverter.

- Accept bias_fn=True/False bool literals on LinearWeightConverter; replaces
  ~10 `lambda c: False` callsites including all `no_bias = lambda c: False`
  named bindings.

- Hoist orphan trailing comments on Apriel2Fixed/PatternDecoderConverter
  into class docstrings.
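
The bool-literal `bias_fn` from the list above can be pictured as a small normalization step. `normalize_bias_fn` is a hypothetical name, and where the real `LinearWeightConverter` performs this check is not stated in the commit message.

```python
def normalize_bias_fn(bias_fn):
    """Accept a bool literal or a callable config -> bool; always return a callable."""
    if isinstance(bias_fn, bool):
        return lambda config: bias_fn
    return bias_fn


never = normalize_bias_fn(False)          # replaces `lambda c: False`
always = normalize_bias_fn(True)
from_config = normalize_bias_fn(lambda c: c["add_linear_biases"])
```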