From 56d9f0d0b815e18a21efa72e8ca2cf19a7953cac Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 May 2026 12:52:40 -0400 Subject: [PATCH 01/12] Introduce weight-side declarative conversion framework Mirrors the post-#508 config-side shape on the weight side. Adds ``_create_weight_converters`` + walker on ``ConfigSectionConverter`` with new primitives (Nested/BlockSequence/Linear/OutputProjection) in ``external.py``. Relocates ``KeyValueWeightConverter``/``TransposeSplitWeightConverter`` (formerly ``MLPLayer2Converter``) so the layers/multimodal converters can import them from the engine instead of llama.py. Migrates llama/mistral/qwen2/mixtral/mtp_llama to the new shape. Tied embeddings move from per-call ``drop_on_export=tied`` plumbing to the walker-central ``OutputProjectionWeightConverter`` marker. Legacy ``get_converters``/``get_parameter_converter``/``get_weight_and_bias_converters`` helpers stay in llama.py as shims for the not-yet-migrated converters (apriel/apriel2/gemma4/multimodal); cleanup commit removes them. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 357 +++++++++++++++++++- fast_llm/models/gpt/conversion/llama.py | 285 ++++++---------- fast_llm/models/gpt/conversion/mixtral.py | 45 +-- fast_llm/models/gpt/conversion/mtp_llama.py | 54 +-- fast_llm/models/gpt/conversion/qwen2.py | 43 +-- 5 files changed, 518 insertions(+), 266 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index d78e6b405..20e1863a5 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -543,17 +543,18 @@ def _consumed_hf_paths(self) -> frozenset[tuple[str, ...]]: class ConfigSectionConverter(abc.ABC): """Base class for converting one Fast-LLM ``Config`` class ↔ one HF dict subtree. - Subclasses declare the conversion via ``_create_config_converters``. Format-specific cross-field - invariants go on the ``_validate_export`` hook. 
The weight side is imperative — concrete subclasses - provide a ``get_converters`` classmethod that emits :class:`WeightConverter` instances. + Subclasses declare the conversion via ``_create_config_converters`` (config side) and + ``_create_weight_converters`` (weight side). Format-specific cross-field invariants go on the + ``_validate_export`` hook. Subclasses that participate in :class:`DispatchConfigConverter` set ``hf_type_name`` to the discriminator value used by the HF format (e.g. ``"attention"``, ``"mamba"``). .. warning:: - :meth:`_create_config_converters` is ``@functools.cache``\\ d on the base class. Subclasses that override - it must return a *fresh* dict (idiomatically ``{**super()._create_config_converters(), ...}``); mutating - the parent's returned dict in place would corrupt the cache entry for every subsequent caller. + Both ``_create_config_converters`` and ``_create_weight_converters`` are ``@functools.cache``\\ d on + the base class. Subclasses that override them must return a *fresh* dict (idiomatically + ``{**super()._create_..._converters(), ...}``); mutating the parent's returned dict in place would + corrupt the cache entry for every subsequent caller. """ fast_llm_config_class: typing.ClassVar[type[Config]] @@ -570,6 +571,18 @@ def _create_config_converters(cls) -> dict[str, ConfigConverter]: """ raise NotImplementedError + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, "WeightConverter"]: + """Return weight-conversion declarations keyed by stable string name. + + Same shape and caching rules as :meth:`_create_config_converters`. Section-relative names; the + walker (:meth:`emit_weight_converters`) prepends the section's full ``(fast_llm_prefix, hf_prefix)`` + pair as it descends. Defaults to no declarations — sections that don't own any weights leave this + unoverridden. 
+ """ + return {} + @classmethod def _validate_export(cls, config: Config) -> None: """Hook for format-specific export-time validation. Default no-op. @@ -606,6 +619,50 @@ def import_config(cls, hf_dict: dict) -> dict: out = {"type": fast_llm_type, **out} return out + @classmethod + def get_converters( + cls, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + drop_on_export: bool = False, + ) -> list["WeightConverter"]: + """Imperative-shape entry point — delegates to the declarative walker. + + Section converters that haven't migrated override this with a custom body; migrated sections leave + it inherited. The ``drop_on_export`` parameter is accepted for signature compatibility with the + pre-migration shape but is unused — the walker handles tied embeddings via + :class:`OutputProjectionWeightConverter`. Once every consumer is migrated this shim and the + parameter are removed. + """ + return cls.emit_weight_converters(config, fast_llm_prefix, hf_prefix) + + @classmethod + def emit_weight_converters( + cls, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config | None = None, + ) -> list["WeightConverter"]: + """Walk this section's weight declarations against ``config`` into a flat list of fully-qualified + :class:`WeightConverter` instances. + + Each declaration in :meth:`_create_weight_converters` returns one or more ``WeightConverter`` leaves via + its :meth:`WeightConverter._emit` hook. Structural primitives (Nested, BlockSequence) recurse into + sub-section converters; leaves return a single prefixed copy of themselves. ``root_config`` carries the + top-level config through the recursion for primitives whose behaviour depends on it (e.g. + :class:`OutputProjectionWeightConverter` consults ``root_config.tied_embedding_weight``); the walker + seeds it from ``config`` on the outermost call. 
+ """ + if root_config is None: + root_config = config + out: list["WeightConverter"] = [] + for declaration in cls._create_weight_converters().values(): + out.extend(declaration._emit(config, fast_llm_prefix, hf_prefix, root_config=root_config)) + return out + @classmethod @functools.cache def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: @@ -715,7 +772,22 @@ def check_architecture_coverage(cls, config: Config) -> None: ) +def _prepend_prefix(prefix: str, names: tuple[str, ...]) -> tuple[str, ...]: + """Prepend ``prefix`` to each name. Empty ``prefix`` is a no-op; empty ``names`` (drop side) stays empty.""" + if not prefix: + return names + return tuple(f"{prefix}.{name}" for name in names) + + class WeightConverter: + """Leaf weight-conversion declaration / emitted instance. + + As a declaration in :meth:`ConfigSectionConverter._create_weight_converters`, the ``fast_llm_name`` and + ``export_name`` are *section-relative*; the walker constructs a fully-qualified emitted copy by prepending + the section prefixes via :meth:`_emit`. Subclasses that need extra construction context (e.g. capturing + a sub-config for use inside ``export_weight``/``import_weight``) override :meth:`_emit` accordingly. + """ + def __init__( self, fast_llm_name: str | tuple[str, ...], @@ -736,6 +808,27 @@ def import_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: return weight + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list["WeightConverter"]: + """Return a fully-qualified emitted copy of this leaf. + + Subclasses that capture extra construction state (e.g. :class:`KeyValueWeightConverter` stashing + an :class:`AttentionConfig`) override this hook to pass that state into the emitted copy. 
+ """ + return [ + type(self)( + _prepend_prefix(fast_llm_prefix, self.fast_llm_name), + _prepend_prefix(hf_prefix, self.export_name), + config, + ) + ] + class IgnoreImportWeightConverter(WeightConverter): def __post_init__(self): @@ -786,6 +879,258 @@ def import_weight( return (torch.cat([weight_[:] for weight_ in weight]),) +class TransposeSplitWeightConverter(WeightConverter): + """Split a merged weight across the last dim with an additional transpose. + + Equivalent to :class:`SplitWeightConverter` for non-gated MLPs (trivial split) and for 1-D biases + (trivial transpose); the real behaviour kicks in for the down-projection of a gated MLP where HF + stores the weight in transposed orientation. + """ + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (merged_weight,) = weight + return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) + return (merged_weight.t().contiguous(),) + + +class KeyValueWeightConverter(WeightConverter): + """Pack/unpack a fused key-value tensor across the two HF names. + + Fast-LLM packs key/value as a single concatenated tensor; HF stores them as two siblings + (``k_proj`` / ``v_proj``). Identity for bias because biases are concatenated the same way. + """ + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + (key_value,) = weight + key, value = key_value[:].chunk(2) + return key, value + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
+ ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + key, value = weight + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +class PatchEmbeddingWeightConverter(WeightConverter): + """Reshape a vision patch-embedding weight from Fast-LLM's flat ``(out, channels*h*w)`` shape to HF's + ``(out, channels, h, w)`` (and back). Requires a config exposing ``input_channels``/``patch_height``/ + ``patch_width`` via the constructor's ``config`` argument.""" + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + return tuple( + weight_[:].view( + *weight_[:].shape[:-1], + self._config.input_channels, + self._config.patch_height, + self._config.patch_width, + ) + for weight_ in weight + ) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + return tuple( + weight_[:].view( + *weight_[:].shape[:-3], + self._config.input_channels * self._config.patch_height * self._config.patch_width, + ) + for weight_ in weight + ) + + +class OutputProjectionWeightConverter(WeightConverter): + """Marker for the LM-head output projection (typically ``head.output_weights`` ↔ ``lm_head.weight``). + + When the root config has ``tied_embedding_weight=True``, the walker drops this declaration entirely — + HF stores tied embeddings as just ``embed_tokens.weight`` with no separate ``lm_head.weight``. Replaces + the per-call ``drop_on_export=exported_config["tie_word_embeddings"]`` plumbing. + """ + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + if getattr(root_config, "tied_embedding_weight", False): + return [] + return super()._emit(config, fast_llm_prefix, hf_prefix, root_config=root_config) + + +class NestedWeightConverter(WeightConverter): + """Recurse into a sub-section's weight declarations. 
+ + The sub-section's config is read from ``getattr(config, config_attr)`` (defaults to ``fast_llm_prefix`` + when the state-dict prefix and the parent's attribute name agree). The walker descends into + ``sub_converter_class._create_weight_converters()`` with extended prefixes. Mirrors + :class:`NestedConfigConverter` on the config side. + + The separate ``config_attr`` covers cases like a block's single ``normalization`` config feeding two + state-dict prefixes (``norm_1`` / ``norm_2``). + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + sub_converter_class: type["ConfigSectionConverter"], + *, + config_attr: str | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._sub_converter_class = sub_converter_class + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + sub_config = getattr(config, self._config_attr) + return self._sub_converter_class.emit_weight_converters( + sub_config, + f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix, + f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix), + root_config=root_config, + ) + + +class BlockSequenceWeightConverter(WeightConverter): + """Fan out a per-block sub-section across every position in a block sequence. + + The sub-section's converter class is resolved per-position from ``block_converter_class``: by default, + the same class for every position; when ``dispatch_registry`` is provided, the per-position config is + matched against the registry keys (Apriel's hybrid-block dispatch — different mixer types per layer). 
+ + Handles both ``FixedBlockSequenceConfig`` (single repeated block) and ``PatternBlockSequenceConfig`` + (per-position blocks indexed via ``decoder.expanded_pattern``). + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + block_converter_class: type["ConfigSectionConverter"], + *, + config_attr: str | None = None, + dispatch_registry: dict[type[Config], type["ConfigSectionConverter"]] | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._block_converter_class = block_converter_class + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + self._dispatch_registry = dispatch_registry + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + # Lazy import to keep external.py free of layers/ dependencies. + from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig + + block_sequence = getattr(config, self._config_attr) + if isinstance(block_sequence, FixedBlockSequenceConfig): + per_position_blocks = [block_sequence.block] * block_sequence.num_blocks + elif isinstance(block_sequence, PatternBlockSequenceConfig): + per_position_blocks = [block_sequence.blocks[name] for name in block_sequence.expanded_pattern] + else: + raise NotImplementedError(type(block_sequence).__name__) + + fast_llm_root = f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix + hf_root = f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix) + out: list[WeightConverter] = [] + for index, block in enumerate(per_position_blocks): + block_class = ( + self._dispatch_registry[type(block.mixer)] + if self._dispatch_registry is not None + else self._block_converter_class + ) + out += block_class.emit_weight_converters( + block, + f"{fast_llm_root}.{index}", + f"{hf_root}.{index}", 
+ root_config=root_config, + ) + return out + + +class LinearWeightConverter(WeightConverter): + """Bundle a linear layer's ``.weight`` and (conditionally) ``.bias`` declarations into one entry. + + Bias presence is resolved at emission time from the live section config: ``bias_fn(config)`` returns + a bool. The default reads ``config.add_linear_biases`` — the shared flag every Llama-style attention/MLP + section carries. Sections with per-layer overrides (e.g. Apriel Mamba's ``dt_layer`` / ``convolution_layer``) + pass a lambda that resolves the override. + + ``transform`` selects the leaf class for both weight and bias: :class:`WeightConverter` for plain rename + (the default), :class:`SplitWeightConverter` for fused → split, :class:`KeyValueWeightConverter` for + fused KV → separate K/V, :class:`TransposeSplitWeightConverter` for MLP down-projection. + + Replaces the imperative ``get_weight_and_bias_converters`` / ``effective_bias`` helpers. + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str | tuple[str, ...] | typing.Callable[[Config], str | tuple[str, ...]], + *, + transform: type[WeightConverter] = WeightConverter, + bias_fn: typing.Callable[[Config], bool] = lambda c: getattr(c, "add_linear_biases", False), + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + # ``hf_prefix`` may be a callable (e.g. Mixtral's ``experts.{i}.w1``-style fan-out where the + # expert count comes from the live config). + self._hf_prefix = hf_prefix + self._transform = transform + self._bias_fn = bias_fn + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + resolved = self._hf_prefix(config) if callable(self._hf_prefix) else self._hf_prefix + hf_prefixes: tuple[str, ...] 
= (resolved,) if isinstance(resolved, str) else tuple(resolved) + weight_fast_llm = _prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.weight",)) + weight_hf = _prepend_prefix(hf_prefix, tuple(f"{p}.weight" for p in hf_prefixes)) + emitted: list[WeightConverter] = [self._transform(weight_fast_llm, weight_hf, config)] + if self._bias_fn(config): + bias_fast_llm = _prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.bias",)) + bias_hf = _prepend_prefix(hf_prefix, tuple(f"{p}.bias" for p in hf_prefixes)) + emitted.append(self._transform(bias_fast_llm, bias_hf, config)) + return emitted + + class ExternalStateDictCheckpointHandler(StateDictCheckpointHandler): _model_class: typing.ClassVar[FastLLMModelConfig] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 425a198f6..29373ae74 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -1,23 +1,27 @@ import dataclasses +import functools import logging import typing -import torch import transformers from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantImportConfigConverter, CustomConfigConverter, DefaultConfigConverter, IgnoredConfigConverter, - IgnoreExportWeightConverter, - IgnoreImportWeightConverter, + KeyValueWeightConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, + OutputProjectionWeightConverter, RenameConfigConverter, SplitWeightConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -32,14 +36,12 @@ from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import ( - 
LanguageModelConfig, LanguageModelEmbeddingsConfig, LanguageModelHeadConfig, ) from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.model import GPTModel -from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div _TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) @@ -59,10 +61,17 @@ def effective_bias(layer_config: AffineLinearConfig, default: bool) -> bool: # ============================================================ -# Weight converters (imperative) +# Legacy imperative weight-conversion helpers — kept for converters that haven't migrated to the +# declarative shape yet. Deleted in the final cleanup commit once every consumer is on +# ``_create_weight_converters``. # ============================================================ +# Re-export under the legacy name so callers that import ``MLPLayer2Converter`` from this module keep +# working during migration. +MLPLayer2Converter = TransposeSplitWeightConverter + + def get_parameter_converter( fast_llm_name: str | tuple[str, ...], hf_name: str | tuple[str, ...], @@ -71,6 +80,8 @@ def get_parameter_converter( drop_on_export: bool = False, drop_on_import: bool = False, ) -> WeightConverter: + from fast_llm.engine.checkpoint.external import IgnoreExportWeightConverter, IgnoreImportWeightConverter + if isinstance(fast_llm_name, str): fast_llm_name = (fast_llm_name,) if isinstance(hf_name, str): @@ -123,42 +134,6 @@ def get_weight_and_bias_converters( return converters -class MLPLayer2Converter(WeightConverter): - # Similar to SplitWeightConverter, but handles the optional MLP transpose. - # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
- ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) - return (merged_weight.t().contiguous(),) - - -class KeyValueWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings, and keeps the key and value separate. - _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (key_value,) = weight - key, value = key_value[:].chunk(2) - return key, value - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - key, value = weight - key_value = torch.cat([key[:], value[:]]) - return (key_value,) - - # ============================================================ # Config converters (declarative) # ============================================================ @@ -256,19 +231,9 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: RMSNormalizationConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return get_weight_and_bias_converters( - fast_llm_prefix, - () if drop_on_export else hf_prefix, - False, - IgnoreExportWeightConverter if drop_on_export else WeightConverter, - ) + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return {"weight": WeightConverter("weight", "weight")} class LlamaMLPConverter(ConfigSectionConverter): @@ -303,29 +268,12 @@ def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) @classmethod - def 
get_converters( - cls, - config: MLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), - config.add_linear_biases, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - config.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter("layer_1", ("gate_proj", "up_proj"), transform=SplitWeightConverter), + "layer_2": LinearWeightConverter("layer_2", "down_proj", transform=TransposeSplitWeightConverter), + } class LlamaAttentionConverter(ConfigSectionConverter): @@ -385,35 +333,13 @@ def _validate_export(cls, config: AttentionConfig) -> None: Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - config.add_linear_biases, - KeyValueWeightConverter, - config, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "query": LinearWeightConverter("query", "q_proj"), + "key_value": 
LinearWeightConverter("key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter), + "dense": LinearWeightConverter("dense", "o_proj"), + } class LlamaBlockConverter(ConfigSectionConverter): @@ -448,35 +374,18 @@ def _validate_export(cls, config: DecoderBlockConfig) -> None: Assert.custom(lambda v: not v, config.output_scale.enabled) @classmethod - def get_converters( - cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False - ) -> list[WeightConverter]: - return [ - *cls.mixer_converter_class.get_converters( - config.mixer, - f"{fast_llm_prefix}.mixer", - f"{hf_prefix}.{cls.hf_mixer_name}", - drop_on_export, - ), - *cls.mlp_converter_class.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - f"{hf_prefix}.{cls.hf_mlp_name}", - drop_on_export, - ), - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.{cls.hf_norm_1_name}", - drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": NestedWeightConverter("mixer", cls.hf_mixer_name, cls.mixer_converter_class), + "mlp": NestedWeightConverter("mlp", cls.hf_mlp_name, cls.mlp_converter_class), + "norm_1": NestedWeightConverter( + "norm_1", cls.hf_norm_1_name, cls.normalization_converter_class, config_attr="normalization" ), - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.{cls.hf_norm_2_name}", - drop_on_export, + "norm_2": NestedWeightConverter( + "norm_2", cls.hf_norm_2_name, cls.normalization_converter_class, config_attr="normalization" ), - ] + } def _llama_decoder_export( @@ -521,22 +430,37 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: FixedBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: 
list[WeightConverter] = [] - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``BlockSequenceWeightConverter`` works on parent-level (decoder lives one level up); here the + # section IS the block sequence, so emit a flat fan-out keyed by block index. + # Used only by Pixtral's vision encoder (the standard text formats inline the block dispatch at + # the base-model converter instead). + return {"blocks": _FixedBlockFanoutWeightConverter(cls.block_converter_class)} + + +class _FixedBlockFanoutWeightConverter(WeightConverter): + """Emit one set of block-sub-converter declarations per position of a ``FixedBlockSequenceConfig``. + + Lives here because ``LlamaDecoderConverter``'s section config *is* the block sequence — there is no + parent attribute to read via the generic :class:`BlockSequenceWeightConverter` shape. + """ + + def __init__(self, block_converter_class: type[ConfigSectionConverter]): + super().__init__((), ()) + self._block_converter_class = block_converter_class + + def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): + Assert.is_(type(config), FixedBlockSequenceConfig) + out: list[WeightConverter] = [] + for index in range(config.num_blocks): + out += self._block_converter_class.emit_weight_converters( config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, + f"{fast_llm_prefix}.{index}" if fast_llm_prefix else str(index), + f"{hf_prefix}.{index}" if hf_prefix else str(index), + root_config=root_config, ) - return converters + return out class LlamaEmbeddingsConverter(ConfigSectionConverter): @@ -562,10 +486,9 @@ def _validate_export(cls, config: LanguageModelEmbeddingsConfig) -> None: Assert.incl(config.position_embeddings.enabled, (None, False)) @classmethod - def get_converters( - cls, config: LanguageModelEmbeddingsConfig, fast_llm_prefix: str, 
hf_prefix: str - ) -> list[WeightConverter]: - return [WeightConverter(f"{fast_llm_prefix}.word_embeddings_weight", f"{hf_prefix}.embed_tokens.weight")] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return {"word_embeddings": WeightConverter("word_embeddings_weight", "embed_tokens.weight")} class LlamaHeadConverter(ConfigSectionConverter): @@ -590,25 +513,30 @@ def _create_config_converters(cls) -> dict: "final_logit_softcap": ConstantImportConfigConverter(("final_logit_softcap",), None), } + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``final_norm`` reads the head's own ``normalization`` config; ``output_weights`` is the marker the + # walker drops automatically when the root config has ``tied_embedding_weight=True``. + return { + "final_norm": NestedWeightConverter( + "final_norm", "model.norm", cls.normalization_converter_class, config_attr="normalization" + ), + "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), + } + @classmethod def get_converters( cls, - config: LanguageModelConfig, + config: GPTBaseModelConfig, exported_config: dict, ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.head.normalization, - f"head.final_norm", - f"model.norm", - ), - get_parameter_converter( - f"head.output_weights", - "lm_head.weight", - drop_on_import=exported_config["tie_word_embeddings"], - drop_on_export=exported_config["tie_word_embeddings"], - ), - ] + """Aggregator-shape shim: non-migrated base-model converters pass the full + :class:`GPTBaseModelConfig` plus the exported HF dict. Translates to the declarative walker — + the tied-embedding handling now lives on :class:`OutputProjectionWeightConverter` and reads + ``root_config.tied_embedding_weight`` directly, so ``exported_config`` is unused. 
+ """ + return cls.emit_weight_converters(config.head, "head", "", root_config=config) class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): @@ -666,22 +594,21 @@ def _decoder_import(hf_dict: dict) -> dict: def _validate_export(cls, config: GPTBaseModelConfig) -> None: assert_no_peft(config) + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``head`` is added at the aggregator level (in :meth:`get_converters`) because the head + # converter's signature takes the root config + the exported HF dict so subclasses can + # extend it (e.g. MTP-Llama fans out per-prediction-head blocks and norms). + return { + "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), + "decoder": BlockSequenceWeightConverter("decoder", "model.layers", cls.block_converter_class), + } + @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - decoder_config = config.decoder - block_config = ( - decoder_config.block - if isinstance(decoder_config, FixedBlockSequenceConfig) - else next(iter(decoder_config.blocks.values())) - ) - block_converters: list[WeightConverter] = [] - for block_index in range(decoder_config.num_blocks): - block_converters += cls.block_converter_class.get_converters( - block_config, f"decoder.{block_index}", f"model.layers.{block_index}" - ) return [ - *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *block_converters, + *cls.emit_weight_converters(config, "", ""), *cls.head_converter_class.get_converters(config, exported_config), ] diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 0403413a9..800d0973a 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -1,16 +1,19 @@ +import functools import typing from fast_llm.engine.checkpoint.config import 
CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantImportConfigConverter, IgnoredConfigConverter, + LinearWeightConverter, RenameConfigConverter, SplitWeightConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.layers.decoder.mlp.config import MoEMLPConfig, RoutingType from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat -from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, @@ -58,35 +61,21 @@ def _validate_export(cls, config: MoEMLPConfig) -> None: Assert.custom(lambda v: not v, config.router_per_expert_scale.enabled) @classmethod - def get_converters( - cls, - config: MoEMLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.router", - f"{hf_prefix}.gate", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - tuple(f"{hf_prefix}.experts.{i}.{w}" for i in range(config.experts) for w in ("w1", "w3")), - False, - SplitWeightConverter, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "router": LinearWeightConverter("router", "gate"), + "layer_1": LinearWeightConverter( + "layer_1", + lambda c: tuple(f"experts.{i}.{w}" for i in range(c.experts) for w in ("w1", "w3")), + transform=SplitWeightConverter, ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - tuple(f"{hf_prefix}.experts.{i}.w2" for i in range(config.experts)), - False, - MLPLayer2Converter, - drop_on_export=drop_on_export, + "layer_2": LinearWeightConverter( + "layer_2", + lambda c: tuple(f"experts.{i}.w2" for i in 
range(c.experts)), + transform=TransposeSplitWeightConverter, ), - ] + } class MixtralBlockConverter(MistralBlockConverter): diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 6f6d9e88a..9c5c90c7e 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -1,17 +1,21 @@ +import functools import typing from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import RenameConfigConverter, WeightConverter -from fast_llm.layers.language_model.config import LanguageModelConfig -from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.engine.checkpoint.external import ( + NestedWeightConverter, + OutputProjectionWeightConverter, + RenameConfigConverter, + WeightConverter, +) +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, - get_parameter_converter, ) from fast_llm.utils import safe_merge_dicts @@ -25,37 +29,41 @@ def _create_config_converters(cls) -> dict: "prediction_heads": RenameConfigConverter(("prediction_heads",), ("prediction_heads",)), } + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # MTP-Llama places the first prediction head's final norm under ``model.mtp_norms.0`` instead + # of the standard ``model.norm``; the additional MTP blocks/norms come from the imperative + # ``get_converters`` override below since their count depends on ``head.prediction_heads``. 
+ return { + "final_norm": NestedWeightConverter( + "final_norm", "model.mtp_norms.0", cls.normalization_converter_class, config_attr="normalization" + ), + "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), + } + @classmethod def get_converters( cls, - config: LanguageModelConfig, + config: GPTBaseModelConfig, exported_config: dict, ) -> list[WeightConverter]: - # MTP-Llama uses ``model.mtp_norms.0`` for the first prediction head's final norm - # instead of the standard ``model.norm``. - converters = [ - *cls.normalization_converter_class.get_converters( - config.head.normalization, - "head.final_norm", - "model.mtp_norms.0", - ), - get_parameter_converter( - "head.output_weights", - "lm_head.weight", - drop_on_import=exported_config["tie_word_embeddings"], - drop_on_export=exported_config["tie_word_embeddings"], - ), - ] + converters = list(cls.emit_weight_converters(config.head, "head", "", root_config=config)) + # Append the MTP fan-out: one block + one norm per extra prediction head. ``block_converter_class`` + # comes from the parent ``LlamaHeadConverter`` ClassVar — the MTP block shape matches the main + # decoder block. 
for prediction_distance in range(2, config.head.prediction_heads + 1): - converters += cls.block_converter_class.get_converters( + converters += cls.block_converter_class.emit_weight_converters( config.decoder.last_block_config, - f"multi_token_prediction.blocks.{prediction_distance-2}", + f"multi_token_prediction.blocks.{prediction_distance - 2}", f"model.mtp_heads.{prediction_distance - 2}", + root_config=config, ) - converters += cls.normalization_converter_class.get_converters( + converters += cls.normalization_converter_class.emit_weight_converters( config.head.normalization, f"multi_token_prediction.heads.{prediction_distance - 2}.final_norm", f"model.mtp_norms.{prediction_distance - 1}", + root_config=config, ) return converters diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index db2b35165..c9177ebea 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,3 +1,4 @@ +import functools import typing from fast_llm.engine.checkpoint.config import CheckpointFormat @@ -5,6 +6,8 @@ ConstantImportConfigConverter, IgnoredConfigConverter, ImportOnlyConfigConverter, + KeyValueWeightConverter, + LinearWeightConverter, WeightConverter, ) from fast_llm.layers.attention.config import AttentionConfig @@ -12,14 +15,12 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( - KeyValueWeightConverter, LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, - get_weight_and_bias_converters, ) from fast_llm.utils import Assert, div @@ -71,36 +72,18 @@ def _validate_export(cls, config: AttentionConfig) -> None: Assert.is_(config.value_layer.bias.enabled, True) Assert.incl(config.dense_layer.bias.enabled, (None, False)) + # Qwen2 hardcodes Q/K/V biases on, dense bias off — 
independent of ``add_linear_biases`` (which is + # pinned to False on the config side because there's no HF ``attention_bias`` field). @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - True, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - True, - KeyValueWeightConverter, - config, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - False, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "query": LinearWeightConverter("query", "q_proj", bias_fn=lambda c: True), + "key_value": LinearWeightConverter( + "key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter, bias_fn=lambda c: True ), - ] + "dense": LinearWeightConverter("dense", "o_proj", bias_fn=lambda c: False), + } class Qwen2MLPConverter(LlamaMLPConverter): From 4a944debe77ebff578578575b789f4f6cf4ffe10 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 May 2026 14:20:18 -0400 Subject: [PATCH 02/12] Migrate apriel/apriel2 to declarative weight conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds ``DispatchWeightConverter`` (runtime-type dispatch on a single attribute) and ``TypedDictWeightConverter`` (per-key dispatch on a ``dict[str, Config]`` attribute) to the framework. Apriel's hybrid block-sequence uses ``BlockSequenceWeightConverter``'s ``dispatch_registry``; Apriel2 uses both new primitives — ``DispatchWeightConverter`` for the block mixer + normalization dispatch, ``TypedDictWeightConverter`` for the StochasticMixer sub-mixer fan-out. 
The Apriel2 Fixed/Pattern decoder section converters now contribute no weights of their own; the block fan-out runs once at the base-model level via ``BlockSequenceWeightConverter``, which already handles both shapes through its ``FixedBlockSequenceConfig`` / ``PatternBlockSequenceConfig`` dispatch. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 94 +++- fast_llm/models/gpt/conversion/apriel.py | 291 +++--------- fast_llm/models/gpt/conversion/apriel2.py | 539 ++++++---------------- 3 files changed, 304 insertions(+), 620 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 20e1863a5..8b593fd18 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1031,11 +1031,15 @@ def __init__( self, fast_llm_prefix: str, hf_prefix: str, - block_converter_class: type["ConfigSectionConverter"], + block_converter_class: type["ConfigSectionConverter"] | None = None, *, config_attr: str | None = None, dispatch_registry: dict[type[Config], type["ConfigSectionConverter"]] | None = None, ): + Assert.custom( + lambda pair: pair[0] is not None or pair[1] is not None, + (block_converter_class, dispatch_registry), + ) super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix self._hf_prefix = hf_prefix @@ -1080,6 +1084,94 @@ def _emit( return out +class DispatchWeightConverter(WeightConverter): + """Dispatch a single sub-section converter based on the live config's runtime type. + + Reads ``getattr(config, config_attr)`` (defaults to ``fast_llm_prefix``), looks up its type in + ``registry``, and recurses into that ConfigSectionConverter with the standard extended prefixes. + Mirrors :class:`DispatchConfigConverter` on the config side. Used when a single attribute holds one + of several alternative configs (e.g. Apriel2's block ``mixer`` may be attention/mamba/gdn/kda/stochastic; + its ``normalization`` may be RMS/Layer/None). 
+ """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + registry: dict[type[Config], type["ConfigSectionConverter"]], + *, + config_attr: str | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._registry = registry + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + sub_config = getattr(config, self._config_attr) + sub_class = self._registry[type(sub_config)] + return sub_class.emit_weight_converters( + sub_config, + f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix, + f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix), + root_config=root_config, + ) + + +class TypedDictWeightConverter(WeightConverter): + """Per-key dispatch over a ``dict[str, Config]`` attribute. + + For each entry, looks up its type in ``registry`` and recurses into that converter with names + ``{fast_llm_prefix}.{key}`` / ``{hf_prefix}.{key}``. Mirrors + :class:`TypedDictContainerConfigConverter` on the config side. Used for collections of named sub- + configs (e.g. Apriel2 StochasticMixer's ``mixers`` dict). 
+ """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + registry: dict[type[Config], type["ConfigSectionConverter"]], + *, + config_attr: str | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._registry = registry + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + attr_dict = getattr(config, self._config_attr) + fast_llm_root = f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix + hf_root = f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix) + out: list[WeightConverter] = [] + for name, sub_config in attr_dict.items(): + sub_class = self._registry[type(sub_config)] + out += sub_class.emit_weight_converters( + sub_config, + f"{fast_llm_root}.{name}" if fast_llm_root else name, + f"{hf_root}.{name}" if hf_root else name, + root_config=root_config, + ) + return out + + class LinearWeightConverter(WeightConverter): """Bundle a linear layer's ``.weight`` and (conditionally) ``.bias`` declarations into one entry. 
diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index c09937aaa..6af2f151e 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -1,3 +1,4 @@ +import functools import math import typing @@ -6,11 +7,13 @@ from fast_llm.config import Config, get_nested_dict_value from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigConverter, ConfigSectionConverter, CustomConfigConverter, DefaultConfigConverter, IgnoredConfigConverter, + LinearWeightConverter, RenameConfigConverter, WeightConverter, _get_attr_path, @@ -20,13 +23,9 @@ from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat -from fast_llm.models.gpt.conversion.llama import ( - effective_bias, - get_parameter_converter, - get_weight_and_bias_converters, -) +from fast_llm.models.gpt.conversion.llama import effective_bias from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, @@ -132,56 +131,23 @@ def _validate_export(cls, config: MambaConfig) -> None: Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters( - cls, - config: MambaConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - # TODO: Conv - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj", - f"{hf_prefix}.in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - 
*get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_in_proj", - f"{hf_prefix}.dt_in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_proj", - f"{hf_prefix}.dt_proj", - effective_bias(config.dt_layer, config.add_linear_biases), - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.conv1d", - effective_bias(config.convolution_layer, config.add_linear_biases), - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.D", - f"{hf_prefix}.D", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "in_proj": LinearWeightConverter("in_proj", "in_proj"), + "dt_in_proj": LinearWeightConverter("dt_in_proj", "dt_in_proj"), + "dt_proj": LinearWeightConverter( + "dt_proj", "dt_proj", bias_fn=lambda c: effective_bias(c.dt_layer, c.add_linear_biases) + ), + "convolution": LinearWeightConverter( + "convolution", + "conv1d", + bias_fn=lambda c: effective_bias(c.convolution_layer, c.add_linear_biases), + ), + "A_log": WeightConverter("A_log", "A_log"), + "D": WeightConverter("D", "D"), + "out_proj": LinearWeightConverter("out_proj", "out_proj"), + } class GatedDeltaNetConverter(ConfigSectionConverter): @@ -223,55 +189,20 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: GatedDeltaNetConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_qkvz", - 
f"{hf_prefix}.in_proj_qkvz", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_ba", - f"{hf_prefix}.in_proj_ba", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.convolution", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - False, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # GDN has no linear biases — explicit ``bias_fn=lambda c: False`` since GatedDeltaNetConfig has + # no ``add_linear_biases`` field for the default to read. 
+ no_bias = lambda c: False + return { + "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=no_bias), + "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=no_bias), + "convolution": LinearWeightConverter("convolution", "convolution", bias_fn=no_bias), + "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=no_bias), + "A_log": WeightConverter("A_log", "A_log"), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "norm": LinearWeightConverter("norm", "norm", bias_fn=no_bias), + } class KimiDeltaAttentionConverter(ConfigSectionConverter): @@ -317,103 +248,30 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: KimiDeltaAttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_proj", - f"{hf_prefix}.q_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_proj", - f"{hf_prefix}.k_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_proj", - f"{hf_prefix}.v_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_conv", - f"{hf_prefix}.q_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_conv", - f"{hf_prefix}.k_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_conv", - f"{hf_prefix}.v_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_a_proj", - f"{hf_prefix}.f_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_b_proj", - f"{hf_prefix}.f_b_proj", - False, - drop_on_export=drop_on_export, - ), - 
*get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_a_proj", - f"{hf_prefix}.g_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_b_proj", - f"{hf_prefix}.g_b_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.beta_proj", - f"{hf_prefix}.beta_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.o_proj", - f"{hf_prefix}.o_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - False, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # KimiDeltaAttention has no linear biases. 
+ no_bias = lambda c: False + proj_names = ( + "q_proj", + "k_proj", + "v_proj", + "q_conv", + "k_conv", + "v_conv", + "f_a_proj", + "f_b_proj", + "g_a_proj", + "g_b_proj", + "beta_proj", + "o_proj", + ) + return { + **{name: LinearWeightConverter(name, name, bias_fn=no_bias) for name in proj_names}, + "A_log": WeightConverter("A_log", "A_log"), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "norm": LinearWeightConverter("norm", "norm", bias_fn=no_bias), + } class AprielKimiDeltaAttentionBlockConverter(MistralBlockConverter): @@ -587,23 +445,18 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - decoder = config.decoder - if isinstance(decoder, FixedBlockSequenceConfig): - per_position_blocks = [decoder.block] * decoder.num_blocks - elif isinstance(decoder, PatternBlockSequenceConfig): - per_position_blocks = [decoder.blocks[block_name] for block_name in decoder.expanded_pattern] - else: - raise NotImplementedError(type(decoder).__name__) - converters = [*cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")] - for block_index, block_config in enumerate(per_position_blocks): - converters += cls.block_dispatcher_class.get_converters( - block_config, - f"decoder.{block_index}", - f"model.layers.{block_index}", - ) - converters += cls.head_converter_class.get_converters(config, exported_config) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # Override the parent's flat ``decoder`` entry with a per-position dispatch version that picks + # the right block converter from the dispatcher's registry based on the mixer's runtime type. 
+ return { + **super()._create_weight_converters(), + "decoder": BlockSequenceWeightConverter( + "decoder", + "model.layers", + dispatch_registry=cls.block_dispatcher_class._converter_classes, + ), + } class AprielHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index d5af6b572..144acbc92 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -1,5 +1,6 @@ """Apriel2 text-only checkpoint format converter.""" +import functools import typing from transformers import PretrainedConfig @@ -7,15 +8,24 @@ from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantImportConfigConverter, CustomConfigConverter, DispatchConfigConverter, + DispatchWeightConverter, IgnoredConfigConverter, + KeyValueWeightConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, OptionalConfigConverter, + OutputProjectionWeightConverter, RenameConfigConverter, + SplitWeightConverter, + TransposeSplitWeightConverter, TypedDictContainerConfigConverter, + TypedDictWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler @@ -35,16 +45,11 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( - KeyValueWeightConverter, LlamaEmbeddingsConverter, LlamaNormalizationConverter, - MLPLayer2Converter, - SplitWeightConverter, assert_no_peft, - effective_bias, - get_parameter_converter, - get_weight_and_bias_converters, ) +from fast_llm.models.gpt.conversion.llama import effective_bias as _effective_bias from fast_llm.models.gpt.model import GPTModel from 
fast_llm.utils import Assert, safe_merge_dicts @@ -177,42 +182,28 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - q_bias = effective_bias(config.query_layer, config.add_linear_biases) - k_bias = effective_bias(config.key_layer, config.add_linear_biases) - v_bias = effective_bias(config.value_layer, config.add_linear_biases) - o_bias = effective_bias(config.dense_layer, config.add_linear_biases) - # k_proj and v_proj are merged in Fast-LLM's key_value layer; treat as biased only if both sides agree. - kv_bias = k_bias and v_bias - - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - q_bias, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - kv_bias, - KeyValueWeightConverter, - config, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # Each linear layer carries its own ``bias.enabled`` override; the default falls back to the + # mixer-wide ``add_linear_biases`` via :func:`_effective_bias`. ``key_value`` is biased only when + # both K and V agree (Fast-LLM packs them as a single tensor). 
+ return { + "query": LinearWeightConverter( + "query", "q_proj", bias_fn=lambda c: _effective_bias(c.query_layer, c.add_linear_biases) + ), + "key_value": LinearWeightConverter( + "key_value", + ("k_proj", "v_proj"), + transform=KeyValueWeightConverter, + bias_fn=lambda c: ( + _effective_bias(c.key_layer, c.add_linear_biases) + and _effective_bias(c.value_layer, c.add_linear_biases) + ), ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - o_bias, - drop_on_export=drop_on_export, + "dense": LinearWeightConverter( + "dense", "o_proj", bias_fn=lambda c: _effective_bias(c.dense_layer, c.add_linear_biases) ), - ] + } def _apriel2_mamba_aux_export(config: MambaConfig) -> dict: @@ -274,55 +265,22 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: MambaConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj", - f"{hf_prefix}.in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_in_proj", - f"{hf_prefix}.dt_in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_proj", - f"{hf_prefix}.dt_proj", - config.dt_layer.bias.enabled, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.conv1d", - config.convolution_layer.bias.enabled, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.D", - f"{hf_prefix}.D", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - config.add_linear_biases, - 
drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # dt_proj and convolution read per-layer ``bias.enabled`` directly (no fallback to the mixer-wide + # flag — Apriel2's HF surfaces these biases via the dedicated ``dt_proj_bias`` / ``conv_bias`` + # auxiliary keys rather than via ``add_linear_biases``). + return { + "in_proj": LinearWeightConverter("in_proj", "in_proj"), + "dt_in_proj": LinearWeightConverter("dt_in_proj", "dt_in_proj"), + "dt_proj": LinearWeightConverter("dt_proj", "dt_proj", bias_fn=lambda c: c.dt_layer.bias.enabled), + "convolution": LinearWeightConverter( + "convolution", "conv1d", bias_fn=lambda c: c.convolution_layer.bias.enabled + ), + "A_log": WeightConverter("A_log", "A_log"), + "D": WeightConverter("D", "D"), + "out_proj": LinearWeightConverter("out_proj", "out_proj"), + } class Apriel2GatedDeltaNetConverter(ConfigSectionConverter): @@ -356,55 +314,20 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: GatedDeltaNetConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_qkvz", - f"{hf_prefix}.in_proj_qkvz", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_ba", - f"{hf_prefix}.in_proj_ba", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.convolution", - config.convolution_layer.bias.enabled, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - 
f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + no_bias = lambda c: False + return { + "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=no_bias), + "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=no_bias), + "convolution": LinearWeightConverter( + "convolution", "convolution", bias_fn=lambda c: c.convolution_layer.bias.enabled + ), + "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=no_bias), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "A_log": WeightConverter("A_log", "A_log"), + "norm": NestedWeightConverter("norm", "norm", LlamaNormalizationConverter, config_attr="normalization"), + } class Apriel2KimiDeltaAttentionConverter(ConfigSectionConverter): @@ -451,103 +374,29 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: KimiDeltaAttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_proj", - f"{hf_prefix}.q_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_proj", - f"{hf_prefix}.k_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_proj", - f"{hf_prefix}.v_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_conv", - f"{hf_prefix}.q_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_conv", - f"{hf_prefix}.k_conv", - False, - drop_on_export=drop_on_export, - ), - 
*get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_conv", - f"{hf_prefix}.v_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_a_proj", - f"{hf_prefix}.f_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_b_proj", - f"{hf_prefix}.f_b_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_a_proj", - f"{hf_prefix}.g_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_b_proj", - f"{hf_prefix}.g_b_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.beta_proj", - f"{hf_prefix}.beta_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.o_proj", - f"{hf_prefix}.o_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + no_bias = lambda c: False + proj_names = ( + "q_proj", + "k_proj", + "v_proj", + "q_conv", + "k_conv", + "v_conv", + "f_a_proj", + "f_b_proj", + "g_a_proj", + "g_b_proj", + "beta_proj", + "o_proj", + ) + return { + **{name: LinearWeightConverter(name, name, bias_fn=no_bias) for name in proj_names}, + "A_log": WeightConverter("A_log", "A_log"), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "norm": NestedWeightConverter("norm", "norm", LlamaNormalizationConverter, config_attr="normalization"), + } # Mixer 
dispatch registry — used inside StochasticMixer (no nested-stochastic) and at the block level. @@ -580,27 +429,9 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: StochasticMixerConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters = [] - for name, sub_mixer in config.mixers.items(): - converter_class = APRIEL2_LEAF_MIXER_REGISTRY.get(type(sub_mixer)) - if converter_class is None: - raise ValueError(f"Unknown sub-mixer type: {type(sub_mixer)}") - converters.extend( - converter_class.get_converters( - sub_mixer, - f"{fast_llm_prefix}.mixers.{name}", - f"{hf_prefix}.mixers.{name}", - drop_on_export=drop_on_export, - ) - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return {"mixers": TypedDictWeightConverter("mixers", "mixers", APRIEL2_LEAF_MIXER_REGISTRY)} # Block-level mixer registry includes StochasticMixer (which can wrap leaf mixers). 
@@ -693,48 +524,24 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: MLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - layer_1_bias = effective_bias(config.layer_1, config.add_linear_biases) - layer_2_bias = effective_bias(config.layer_2, config.add_linear_biases) - if config.gated: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), - layer_1_bias, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - f"{hf_prefix}.up_proj", - layer_1_bias, - WeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``layer_1`` splits into ``(gate_proj, up_proj)`` when gated, but stays as a single ``up_proj`` + # otherwise. The transform and HF prefix both depend on the live config; resolve at emit time. 
+ return { + "layer_1": LinearWeightConverter( + "layer_1", + lambda c: ("gate_proj", "up_proj") if c.gated else ("up_proj",), + transform=SplitWeightConverter, + bias_fn=lambda c: _effective_bias(c.layer_1, c.add_linear_biases), + ), + "layer_2": LinearWeightConverter( + "layer_2", + "down_proj", + transform=TransposeSplitWeightConverter, + bias_fn=lambda c: _effective_bias(c.layer_2, c.add_linear_biases), ), - ] + } class Apriel2BlockConverter(ConfigSectionConverter): @@ -775,49 +582,19 @@ def _validate_export(cls, config: DecoderBlockConfig) -> None: Assert.custom(lambda v: not v, config.output_scale.enabled) @classmethod - def get_converters( - cls, - config: DecoderBlockConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - mixer_converter_class = APRIEL2_BLOCK_MIXER_REGISTRY.get(type(config.mixer)) - if mixer_converter_class is None: - raise ValueError(f"Unknown mixer type: {type(config.mixer)}") - converters: list[WeightConverter] = list( - mixer_converter_class.get_converters( - config.mixer, - f"{fast_llm_prefix}.mixer", - f"{hf_prefix}.mixer", - drop_on_export=drop_on_export, - ) - ) - converters.extend( - Apriel2MLPConverter.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - f"{hf_prefix}.mlp", - drop_on_export=drop_on_export, - ) - ) - converters.extend( - [ - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.input_layernorm", - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.post_attention_layernorm", - drop_on_export=drop_on_export, - ), - ] - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": DispatchWeightConverter("mixer", "mixer", APRIEL2_BLOCK_MIXER_REGISTRY), + "mlp": NestedWeightConverter("mlp", "mlp", Apriel2MLPConverter), + # 
The two state-dict norms (norm_1/norm_2) share the block's single ``normalization`` config. + "norm_1": NestedWeightConverter( + "norm_1", "input_layernorm", LlamaNormalizationConverter, config_attr="normalization" + ), + "norm_2": NestedWeightConverter( + "norm_2", "post_attention_layernorm", LlamaNormalizationConverter, config_attr="normalization" + ), + } class Apriel2FixedDecoderConverter(ConfigSectionConverter): @@ -832,23 +609,10 @@ def _create_config_converters(cls) -> dict: "block": NestedConfigConverter(("block",), cls.block_converter_class, hf_path=("block",)), } - @classmethod - def get_converters( - cls, - config: FixedBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export=drop_on_export, - ) - return converters + # The block fan-out lives on the base-model converter, which uses :class:`BlockSequenceWeightConverter` + # directly (Fixed/Pattern dispatch and block iteration share one primitive). The Fixed/Pattern decoder + # section converters exist for the config side (dispatch via :class:`DispatchConfigConverter`) and + # contribute no weights of their own. 
class Apriel2PatternDecoderConverter(ConfigSectionConverter): @@ -868,24 +632,7 @@ def _create_config_converters(cls) -> dict: ), } - @classmethod - def get_converters( - cls, - config: PatternBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - for block_index in range(config.num_blocks): - block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export=drop_on_export, - ) - return converters + # See note on :class:`Apriel2FixedDecoderConverter` — block fan-out lives on the base-model converter. APRIEL2_DECODER_REGISTRY: dict[type[Config], type[ConfigSectionConverter]] = { @@ -932,25 +679,14 @@ def _validate_export(cls, config: LanguageModelHeadConfig) -> None: Assert.is_(type(config.normalization), RMSNormalizationConfig) @classmethod - def get_converters( - cls, - config: LanguageModelHeadConfig, - exported_config: dict, - fast_llm_prefix: str, - ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.final_norm", - "model.norm", - ), - get_parameter_converter( - f"{fast_llm_prefix}.output_weights", - "lm_head.weight", - drop_on_import=exported_config.get("tie_word_embeddings", False), - drop_on_export=exported_config.get("tie_word_embeddings", False), + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "final_norm": NestedWeightConverter( + "final_norm", "model.norm", cls.normalization_converter_class, config_attr="normalization" ), - ] + "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), + } class Apriel2BaseModelConverter(ConfigSectionConverter): @@ -985,15 +721,18 @@ def _create_config_converters(cls) -> 
dict: def _validate_export(cls, config: GPTBaseModelConfig) -> None: assert_no_peft(config) + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), + "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", Apriel2BlockConverter), + "head": NestedWeightConverter("head", "", cls.head_converter_class), + } + @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - return [ - *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *get_apriel2_decoder_converter(config.decoder).get_converters( - config.decoder, "decoder", "model.decoder.blocks" - ), - *cls.head_converter_class.get_converters(config.head, exported_config, "head"), - ] + return cls.emit_weight_converters(config, "", "") class Apriel2HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): From fb369a0130c5d6413eb8784d3790ef9c54b049ee Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 May 2026 15:14:38 -0400 Subject: [PATCH 03/12] Migrate Llava and Apriel2-multimodal to declarative weight conversion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Last of the model-side migrations: the multimodal Llava (Pixtral vision + Mistral text) and the multimodal Apriel2 (Apriel2 vision + Apriel2 text base) handlers now declare their weight conversions via ``_create_weight_converters``. ``PatchEmbeddingWeightConverter`` is now imported from the engine (relocated earlier); the local copies in ``llava.py`` are removed. Gemma4 keeps its imperative ``get_converters`` and continues to work via the ``ConfigSectionConverter.get_converters`` shim — its helper classes don't inherit ``ConfigSectionConverter`` so they don't get a free declarative migration. Revisit in cleanup or a follow-up. 
Co-Authored-By: Claude Sonnet 4.6 --- .../models/multimodal/conversion/apriel2.py | 204 ++++++++---------- .../models/multimodal/conversion/llava.py | 180 ++++++---------- 2 files changed, 155 insertions(+), 229 deletions(-) diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 6fe509a9d..7ff544ee2 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -1,17 +1,24 @@ """Apriel2 multimodal checkpoint format converter.""" +import functools import typing from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantExportConfigConverter, ConstantImportConfigConverter, CustomConfigConverter, IgnoredConfigConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, OptionalConfigConverter, + OutputProjectionWeightConverter, + PatchEmbeddingWeightConverter, RenameConfigConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -22,27 +29,18 @@ from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig -from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.apriel2 import ( Apriel2BaseModelConverter, + Apriel2BlockConverter, Apriel2HeadConverter, Apriel2MLPConverter, Apriel2RMSNormConverter, - get_apriel2_decoder_converter, -) -from fast_llm.models.gpt.conversion.llama import ( - LlamaEmbeddingsConverter, - LlamaNormalizationConverter, - get_parameter_converter, - get_weight_and_bias_converters, ) +from 
fast_llm.models.gpt.conversion.llama import LlamaEmbeddingsConverter, LlamaNormalizationConverter from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import Apriel2CheckpointFormat -from fast_llm.models.multimodal.conversion.llava import ( - PatchEmbeddingWeightConverter, - PixtralAttentionConverter, -) +from fast_llm.models.multimodal.conversion.llava import PixtralAttentionConverter from fast_llm.models.multimodal.model import MultiModalModel from fast_llm.utils import Assert @@ -198,29 +196,18 @@ def _validate_export(cls, config: DecoderBlockConfig) -> None: Assert.custom(lambda v: not v, config.output_scale.enabled) @classmethod - def get_converters( - cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False - ) -> list[WeightConverter]: - return [ - *cls.mixer_converter_class.get_converters( - config.mixer, f"{fast_llm_prefix}.mixer", f"{hf_prefix}.{cls.hf_mixer_name}", drop_on_export - ), - *cls.mlp_converter_class.get_converters( - config.mlp, f"{fast_llm_prefix}.mlp", f"{hf_prefix}.{cls.hf_mlp_name}", drop_on_export - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.{cls.hf_norm_1_name}", - drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": NestedWeightConverter("mixer", cls.hf_mixer_name, cls.mixer_converter_class), + "mlp": NestedWeightConverter("mlp", cls.hf_mlp_name, cls.mlp_converter_class), + "norm_1": NestedWeightConverter( + "norm_1", cls.hf_norm_1_name, LlamaNormalizationConverter, config_attr="normalization" ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.{cls.hf_norm_2_name}", - drop_on_export, + "norm_2": NestedWeightConverter( + "norm_2", cls.hf_norm_2_name, LlamaNormalizationConverter, 
config_attr="normalization" ), - ] + } class Apriel2VisionEncoderConverter(ConfigSectionConverter): @@ -264,22 +251,13 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: FixedBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # The section config IS the FixedBlockSequenceConfig — use the LlamaDecoderConverter pattern of a + # custom fan-out primitive that reads ``config.block`` and ``config.num_blocks`` directly. + from fast_llm.models.gpt.conversion.llama import _FixedBlockFanoutWeightConverter + + return {"blocks": _FixedBlockFanoutWeightConverter(cls.block_converter_class)} class Apriel2EmbeddingsConverter(ConfigSectionConverter): @@ -322,21 +300,19 @@ def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) @classmethod - def get_converters( - cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.patch_embeddings", - f"{hf_prefix}.patch_embeddings", - False, - PatchEmbeddingWeightConverter, - config, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "patch_embeddings": LinearWeightConverter( + "patch_embeddings", + "patch_embeddings", + transform=PatchEmbeddingWeightConverter, + bias_fn=lambda c: False, ), - *LlamaNormalizationConverter.get_converters( - config.normalization, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.normalization" + 
"normalization": NestedWeightConverter( + "normalization", "normalization", LlamaNormalizationConverter, config_attr="normalization" ), - ] + } class Apriel2VisionAdapterConverter(Apriel2VisionMLPConverter): @@ -353,27 +329,12 @@ def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters( - cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False - ) -> list[WeightConverter]: - from fast_llm.models.gpt.conversion.llama import MLPLayer2Converter - - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - f"{hf_prefix}.linear_1", - config.add_linear_biases, - WeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.linear_2", - config.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter("layer_1", "linear_1"), + "layer_2": LinearWeightConverter("layer_2", "linear_2", transform=TransposeSplitWeightConverter), + } class Apriel2VisionModelConverter(ConfigSectionConverter): @@ -427,41 +388,32 @@ def _inject_rotary_metadata(config: VisionEncoderConfig) -> dict: return {} @classmethod - def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: - return [ - *cls.embeddings_converter_class.get_converters( - config.embeddings, "vision_encoder.embeddings", cls.hf_embeddings_prefix - ), - *cls.encoder_converter_class.get_converters( - config.encoder, "vision_encoder.encoder", cls.hf_encoder_prefix - ), - *cls.vision_adapter_converter_class.get_converters( - config.adapter, "vision_encoder.adapter", cls.hf_adapter_prefix + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # The vision converter normally lives under a single 
``vision_encoder`` HF prefix, but Apriel2's + # state-dict places each piece at distinct absolute paths (``model.vision_encoder.{embeddings, + # encoder.blocks, adapter}``). Use ``NestedWeightConverter`` with the absolute prefix on each entry; + # the multimodal base-model converter's ``get_converters`` emits this section with ``hf_prefix=""``, + # so the absolute prefix lands as-is. + return { + "embeddings": NestedWeightConverter( + "embeddings", cls.hf_embeddings_prefix, cls.embeddings_converter_class ), - ] + "encoder": NestedWeightConverter("encoder", cls.hf_encoder_prefix, cls.encoder_converter_class), + "adapter": NestedWeightConverter("adapter", cls.hf_adapter_prefix, cls.vision_adapter_converter_class), + } class Apriel2MultimodalHeadConverter(Apriel2HeadConverter): @classmethod - def get_converters( - cls, - config: LanguageModelHeadConfig, - exported_config: dict, - fast_llm_prefix: str, - ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.final_norm", - "model.norm", - ), - get_parameter_converter( - f"{fast_llm_prefix}.output_weights", - "lm_head.weight", - drop_on_import=exported_config.get("tie_word_embeddings", False), - drop_on_export=exported_config.get("tie_word_embeddings", False), + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "final_norm": NestedWeightConverter( + "final_norm", "model.norm", cls.normalization_converter_class, config_attr="normalization" ), - ] + "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), + } class Apriel2MultimodalBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): @@ -528,18 +480,30 @@ def _vision_import(hf_dict: dict) -> dict: "image_token_index": OptionalConfigConverter(("image_token_index",), ("image_token_index",)), } + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + 
return { + # ``embeddings`` flat-merges into ``model``; vision_encoder writes to its own absolute prefix. + "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), + "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", Apriel2BlockConverter), + "head": NestedWeightConverter("head", "", cls.head_converter_class), + } + @classmethod def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - converters: list[WeightConverter] = [] + converters = list(cls.emit_weight_converters(config, "", "")) if config.vision_encoder is not None: - converters.extend(cls.vision_model_converter_class.get_converters(config.vision_encoder)) - converters.extend(cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")) - converters.extend( - get_apriel2_decoder_converter(config.decoder).get_converters( - config.decoder, "decoder", "model.decoder.blocks" + # Vision encoder is optional — emit only when present. The Nested declaration in + # :meth:`_create_weight_converters` couldn't conditionally fire on a None attribute. 
+ converters = ( + list( + cls.vision_model_converter_class.emit_weight_converters( + config.vision_encoder, "vision_encoder", "", root_config=config + ) + ) + + converters ) - ) - converters.extend(cls.head_converter_class.get_converters(config.head, exported_config, "head")) return converters diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index fe18ef3cb..274cea39a 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -1,17 +1,21 @@ +import functools import typing -import torch - from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantExportConfigConverter, ConstantImportConfigConverter, CustomConfigConverter, IgnoredConfigConverter, ImportOnlyConfigConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, + PatchEmbeddingWeightConverter, RenameConfigConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -19,9 +23,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import Rotary2DConfig -from fast_llm.layers.block.config import FixedBlockSequenceConfig from fast_llm.layers.decoder.mlp.config import MLPConfig -from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.llama import ( _TRANSFORMERS_V4, @@ -29,15 +31,11 @@ LlamaBlockConverter, LlamaDecoderConverter, LlamaNormalizationConverter, - MLPLayer2Converter, - get_parameter_converter, - get_weight_and_bias_converters, ) from fast_llm.models.gpt.conversion.mistral import 
MistralBaseModelConverter, MistralHeadConverter, MistralMLPConverter from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel -from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div @@ -123,34 +121,6 @@ class PixtralEncoderConverter(LlamaDecoderConverter): block_converter_class: typing.ClassVar[type[PixtralBlockConverter]] = PixtralBlockConverter -class PatchEmbeddingWeightConverter(WeightConverter): - _config: PatchEmbeddingsConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return tuple( - weight_[:].view( - *weight_[:].shape[:-1], - self._config.input_channels, - self._config.patch_height, - self._config.patch_width, - ) - for weight_ in weight - ) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return tuple( - weight_[:].view( - *weight_[:].shape[:-3], - self._config.input_channels * self._config.patch_height * self._config.patch_width, - ) - for weight_ in weight - ) - - class PixtralEmbeddingsConverter(ConfigSectionConverter): """Converts ``PatchEmbeddingsConfig`` ↔ Pixtral HF flat fields (``patch_size`` / ``num_channels``). 
@@ -184,21 +154,24 @@ def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) @classmethod - def get_converters( - cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.patch_embeddings", - f"{hf_prefix}.patch_conv", - False, - PatchEmbeddingWeightConverter, - config, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "patch_embeddings": LinearWeightConverter( + "patch_embeddings", + "patch_conv", + transform=PatchEmbeddingWeightConverter, + bias_fn=lambda c: False, ), - *cls.normalization_converter_class.get_converters( - config, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.ln_pre" + # ``PixtralEmbeddingsConverter``'s section config IS the ``PatchEmbeddingsConfig`` (carries the + # normalization sub-config directly), so the nested normalization converter reads that sub-config + # via ``config_attr="normalization"``. Note that the pre-migration ``get_converters`` passed the + # whole ``PatchEmbeddingsConfig`` to the normalization converter instead; the declaration below + # passes the ``normalization`` sub-config rather than the parent section config. 
+ "normalization": NestedWeightConverter( + "normalization", "ln_pre", cls.normalization_converter_class, config_attr="normalization" ), - ] + } class LlavaVisionAdapterConverter(ConfigSectionConverter): @@ -243,21 +216,12 @@ def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - f"{hf_prefix}.linear_1", - config.add_linear_biases, - WeightConverter, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.linear_2", - config.add_linear_biases, - MLPLayer2Converter, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter("layer_1", "linear_1"), + "layer_2": LinearWeightConverter("layer_2", "linear_2", transform=TransposeSplitWeightConverter), + } class LlavaVisionModelConverter(ConfigSectionConverter): @@ -315,39 +279,36 @@ def _validate_export(cls, config: VisionEncoderConfig) -> None: Assert.eq(mixer.head_size * mixer.heads, config.hidden_size) @classmethod - def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: - return [ - *cls.embeddings_converter_class.get_converters( - config.embeddings, "vision_encoder.embeddings", "vision_tower" - ), - *cls.encoder_converter_class.get_converters( - config.encoder, "vision_encoder.encoder", "vision_tower.transformer.layers" - ), - *LlavaVisionAdapterConverter.get_converters( - config.adapter, "vision_encoder.adapter", "multi_modal_projector" + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "embeddings": NestedWeightConverter("embeddings", "vision_tower", cls.embeddings_converter_class), + # The encoder section IS a FixedBlockSequenceConfig — fan out blocks via the 
dedicated primitive. + "encoder": NestedWeightConverter( + "encoder", "vision_tower.transformer.layers", cls.encoder_converter_class ), - ] + "adapter": NestedWeightConverter("adapter", "multi_modal_projector", LlavaVisionAdapterConverter), + } class LlavaHeadConverter(MistralHeadConverter): + # Llava always writes ``lm_head.weight`` on export (never dropped, even when ``tied_embedding_weight=True``); + # the parent's :class:`OutputProjectionWeightConverter` would also drop on export, so we replace it with a + # plain rename. When the HF state-dict lacks ``lm_head.weight`` (tied case), the handler's per-converter + # ``all(name in state_dict)`` check makes the rename a silent no-op on import — equivalent to the previous + # ``drop_on_import=tied`` behaviour, without the extra parameter plumbing. @classmethod - def get_converters( - cls, - config: LanguageModelConfig, - exported_config: dict, - ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.head.normalization, - f"head.final_norm", - f"language_model.model.norm", - ), - get_parameter_converter( - f"head.output_weights", - "language_model.lm_head.weight", - drop_on_import=exported_config["tie_word_embeddings"], + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "final_norm": NestedWeightConverter( + "final_norm", + "language_model.model.norm", + cls.normalization_converter_class, + config_attr="normalization", ), - ] + "output_weights": WeightConverter("output_weights", "language_model.lm_head.weight"), + } class LlavaLanguageModelConverter(MistralBaseModelConverter): @@ -428,26 +389,27 @@ def _validate_export(cls, config: MultiModalBaseModelConfig) -> None: assert config.image_token_index is not None, "Llava requires an image_token_index" @classmethod - def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + @functools.cache + def _create_weight_converters(cls) 
-> dict[str, WeightConverter]: text_base_cls = cls.language_model_converter_class - decoder_config = config.decoder - block_config = ( - decoder_config.block - if isinstance(decoder_config, FixedBlockSequenceConfig) - else next(iter(decoder_config.blocks.values())) - ) - block_converters: list[WeightConverter] = [] - for block_index in range(decoder_config.num_blocks): - block_converters += text_base_cls.block_converter_class.get_converters( - block_config, f"decoder.{block_index}", f"language_model.model.layers.{block_index}" - ) - return [ - *cls.vision_model_converter_class.get_converters(config.vision_encoder), - *text_base_cls.embeddings_converter_class.get_converters( - config.embeddings, "embeddings", "language_model.model" + return { + "vision_encoder": NestedWeightConverter("vision_encoder", "", cls.vision_model_converter_class), + "embeddings": NestedWeightConverter( + "embeddings", "language_model.model", text_base_cls.embeddings_converter_class ), - *block_converters, - *text_base_cls.head_converter_class.get_converters(config, {"tie_word_embeddings": False}), + "decoder": BlockSequenceWeightConverter( + "decoder", "language_model.model.layers", text_base_cls.block_converter_class + ), + } + + @classmethod + def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + # ``head`` is added at the aggregator level because the LlavaHead's plain WeightConverter for + # ``language_model.lm_head.weight`` doesn't fit a NestedWeightConverter under any HF prefix — + # it lives at the HF root, not inside ``language_model.model``. 
+ return [ + *cls.emit_weight_converters(config, "", ""), + *cls.language_model_converter_class.head_converter_class.get_converters(config, exported_config), ] From 890fc864682fbb25c8b08f34e1bb7a28d5e89ef1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 May 2026 16:27:53 -0400 Subject: [PATCH 04/12] Drop legacy weight-conversion helpers and shims MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes the post-migration deadweight: * ``ConfigSectionConverter.get_converters`` (section-shape shim) — every consumer now calls ``emit_weight_converters`` directly. * ``get_parameter_converter``, ``get_weight_and_bias_converters``, ``MLPLayer2Converter`` alias in llama.py — no remaining callers. * ``drop_on_export`` parameter plumbing throughout gemma4 — the only legitimate use case (head tied embeddings) is handled by ``OutputProjectionWeightConverter`` at the walker level. Gemma4 gains a local ``_linear_converters`` helper that builds ``.weight`` and (optional) ``.bias`` ``WeightConverter`` instances directly — Gemma4's helper classes don't inherit ``ConfigSectionConverter`` so the ``LinearWeightConverter`` declarative primitive doesn't apply. ``effective_bias`` stays in llama.py as a published helper — still used by Apriel/Apriel2 config-side ``CustomConfigConverter`` export_fns and the matching ``LinearWeightConverter.bias_fn`` lambdas. 
Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 18 --- fast_llm/models/gpt/conversion/gemma4.py | 146 ++++++++++------------- fast_llm/models/gpt/conversion/llama.py | 80 +------------ 3 files changed, 68 insertions(+), 176 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 8b593fd18..5ac7004c9 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -619,24 +619,6 @@ def import_config(cls, hf_dict: dict) -> dict: out = {"type": fast_llm_type, **out} return out - @classmethod - def get_converters( - cls, - config: Config, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list["WeightConverter"]: - """Imperative-shape entry point — delegates to the declarative walker. - - Section converters that haven't migrated override this with a custom body; migrated sections leave - it inherited. The ``drop_on_export`` parameter is accepted for signature compatibility with the - pre-migration shape but is unused — the walker handles tied embeddings via - :class:`OutputProjectionWeightConverter`. Once every consumer is migrated this shim and the - parameter are removed. 
- """ - return cls.emit_weight_converters(config, fast_llm_prefix, hf_prefix) - @classmethod def emit_weight_converters( cls, diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 6cbb7898c..1e5e54468 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -9,8 +9,10 @@ ConstantExportConfigConverter, CustomConfigConverter, IgnoredConfigConverter, + KeyValueWeightConverter, RenameConfigConverter, SplitWeightConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -28,17 +30,48 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Gemma4CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( - KeyValueWeightConverter, LlamaEmbeddingsConverter, LlamaHeadConverter, LlamaNormalizationConverter, - MLPLayer2Converter, - get_parameter_converter, - get_weight_and_bias_converters, ) from fast_llm.models.gpt.model import GPTModel from fast_llm.utils import Assert, safe_merge_dicts + +def _linear_converters( + fast_llm_prefix: str, + hf_prefix: str | tuple[str, ...], + use_bias: bool, + transform: type[WeightConverter] = WeightConverter, + config=None, +) -> list[WeightConverter]: + """Local helper: build ``.weight`` and (conditional) ``.bias`` converters for one linear layer. + + Gemma4's helper classes don't inherit ``ConfigSectionConverter``, so the + :class:`LinearWeightConverter` declarative primitive doesn't apply directly — this is the + smallest helper that covers the Gemma4-specific imperative ``get_converters`` shape. ``config`` + is forwarded to the transform constructor when the transform captures it + (e.g. :class:`KeyValueWeightConverter`). 
+ """ + hf_names = (hf_prefix,) if isinstance(hf_prefix, str) else tuple(hf_prefix) + converters = [ + transform( + f"{fast_llm_prefix}.weight", + tuple(f"{name}.weight" for name in hf_names), + config, + ) + ] + if use_bias: + converters.append( + transform( + f"{fast_llm_prefix}.bias", + tuple(f"{name}.bias" for name in hf_names), + config, + ) + ) + return converters + + _SLIDING_ATTENTION = "sliding_attention" _FULL_ATTENTION = "full_attention" @@ -159,53 +192,46 @@ def get_converters( config: AttentionConfig, fast_llm_prefix: str, hf_prefix: str, - drop_on_export: bool = False, ) -> list[WeightConverter]: if config.shared_key_value: # K=V: single k_proj reused as value; no v_proj in HF - kv_converters = get_weight_and_bias_converters( + kv_converters = _linear_converters( f"{fast_llm_prefix}.key_value", f"{hf_prefix}.k_proj", False, - drop_on_export=drop_on_export, ) else: - kv_converters = get_weight_and_bias_converters( + kv_converters = _linear_converters( f"{fast_llm_prefix}.key_value", (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), False, KeyValueWeightConverter, config, - drop_on_export=drop_on_export, ) converters = [ - *get_weight_and_bias_converters( + *_linear_converters( f"{fast_llm_prefix}.query", f"{hf_prefix}.q_proj", False, - drop_on_export=drop_on_export, ), *kv_converters, - *get_weight_and_bias_converters( + *_linear_converters( f"{fast_llm_prefix}.dense", f"{hf_prefix}.o_proj", False, - drop_on_export=drop_on_export, ), ] if config.query_norm is not None: - converters += LlamaNormalizationConverter.get_converters( + converters += LlamaNormalizationConverter.emit_weight_converters( config.query_norm, f"{fast_llm_prefix}.query_norm", f"{hf_prefix}.q_norm", - drop_on_export=drop_on_export, ) if config.key_norm is not None: - converters += LlamaNormalizationConverter.get_converters( + converters += LlamaNormalizationConverter.emit_weight_converters( config.key_norm, f"{fast_llm_prefix}.key_norm", f"{hf_prefix}.k_norm", - 
drop_on_export=drop_on_export, ) # value_norm is FixedRMSNorm — no learnable weight to convert return converters @@ -239,22 +265,19 @@ def get_converters( config: MLPConfig, fast_llm_prefix: str, hf_prefix: str, - drop_on_export: bool = False, ) -> list[WeightConverter]: return [ - *get_weight_and_bias_converters( + *_linear_converters( f"{fast_llm_prefix}.layer_1", (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), False, SplitWeightConverter, - drop_on_export=drop_on_export, ), - *get_weight_and_bias_converters( + *_linear_converters( f"{fast_llm_prefix}.layer_2", f"{hf_prefix}.down_proj", False, - MLPLayer2Converter, - drop_on_export=drop_on_export, + TransposeSplitWeightConverter, ), ] @@ -311,39 +334,17 @@ def get_converters( config: MoEMLPConfig, fast_llm_prefix: str, hf_prefix: str, - drop_on_export: bool = False, ) -> list[WeightConverter]: converters = [ - *get_weight_and_bias_converters( + *_linear_converters( f"{fast_llm_prefix}.router", f"{hf_prefix}.router.proj", False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.router_scale", - f"{hf_prefix}.router.scale", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.router_per_expert_scale", - f"{hf_prefix}.router.per_expert_scale", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.layer_1.weight", - f"{hf_prefix}.experts.gate_up_proj", - Gemma4MoELayer1Converter, - config, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.layer_2.weight", - f"{hf_prefix}.experts.down_proj", - Gemma4MoELayer2Converter, - config, - drop_on_export=drop_on_export, ), + WeightConverter(f"{fast_llm_prefix}.router_scale", f"{hf_prefix}.router.scale"), + WeightConverter(f"{fast_llm_prefix}.router_per_expert_scale", f"{hf_prefix}.router.per_expert_scale"), + Gemma4MoELayer1Converter(f"{fast_llm_prefix}.layer_1.weight", f"{hf_prefix}.experts.gate_up_proj", config), + 
Gemma4MoELayer2Converter(f"{fast_llm_prefix}.layer_2.weight", f"{hf_prefix}.experts.down_proj", config), ] # router.norm is FixedRMSNorm — no learnable weight to convert. return converters @@ -380,44 +381,37 @@ def get_converters( config: HybridMoEMLPConfig, fast_llm_prefix: str, hf_prefix: str, - drop_on_export: bool = False, ) -> list[WeightConverter]: return [ *Gemma4MLPConverter.get_converters( config.dense, f"{fast_llm_prefix}.dense", f"{hf_prefix}.mlp", - drop_on_export=drop_on_export, ), *Gemma4MoEMLPConverter.get_converters( config.routed, f"{fast_llm_prefix}.routed", hf_prefix, - drop_on_export=drop_on_export, ), - *LlamaNormalizationConverter.get_converters( + *LlamaNormalizationConverter.emit_weight_converters( config.dense.pre_norm, f"{fast_llm_prefix}.dense.pre_norm", f"{hf_prefix}.pre_feedforward_layernorm", - drop_on_export=drop_on_export, ), - *LlamaNormalizationConverter.get_converters( + *LlamaNormalizationConverter.emit_weight_converters( config.dense.post_norm, f"{fast_llm_prefix}.dense.post_norm", f"{hf_prefix}.post_feedforward_layernorm_1", - drop_on_export=drop_on_export, ), - *LlamaNormalizationConverter.get_converters( + *LlamaNormalizationConverter.emit_weight_converters( config.routed.pre_norm, f"{fast_llm_prefix}.routed.pre_norm", f"{hf_prefix}.pre_feedforward_layernorm_2", - drop_on_export=drop_on_export, ), - *LlamaNormalizationConverter.get_converters( + *LlamaNormalizationConverter.emit_weight_converters( config.routed.post_norm, f"{fast_llm_prefix}.routed.post_norm", f"{hf_prefix}.post_feedforward_layernorm_2", - drop_on_export=drop_on_export, ), ] @@ -476,7 +470,6 @@ def get_converters( config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, - drop_on_export: bool = False, ) -> list[WeightConverter]: is_moe = isinstance(config.mlp, HybridMoEMLPConfig) converters = [ @@ -484,7 +477,6 @@ def get_converters( config.mixer, f"{fast_llm_prefix}.mixer", f"{hf_prefix}.self_attn", - drop_on_export=drop_on_export, ), ] if is_moe: 
@@ -492,48 +484,36 @@ def get_converters( config.mlp, f"{fast_llm_prefix}.mlp", hf_prefix, - drop_on_export=drop_on_export, ) else: converters += Gemma4MLPConverter.get_converters( config.mlp, f"{fast_llm_prefix}.mlp", f"{hf_prefix}.mlp", - drop_on_export=drop_on_export, ) - converters += LlamaNormalizationConverter.get_converters( + converters += LlamaNormalizationConverter.emit_weight_converters( config.normalization, f"{fast_llm_prefix}.norm_2", f"{hf_prefix}.pre_feedforward_layernorm", - drop_on_export=drop_on_export, ) converters += [ - *LlamaNormalizationConverter.get_converters( + *LlamaNormalizationConverter.emit_weight_converters( config.normalization, f"{fast_llm_prefix}.norm_1", f"{hf_prefix}.input_layernorm", - drop_on_export=drop_on_export, ), - *LlamaNormalizationConverter.get_converters( + *LlamaNormalizationConverter.emit_weight_converters( config.post_mixer_normalization, f"{fast_llm_prefix}.post_mixer_norm", f"{hf_prefix}.post_attention_layernorm", - drop_on_export=drop_on_export, ), - *LlamaNormalizationConverter.get_converters( + *LlamaNormalizationConverter.emit_weight_converters( config.post_mlp_normalization, f"{fast_llm_prefix}.post_mlp_norm", f"{hf_prefix}.post_feedforward_layernorm", - drop_on_export=drop_on_export, ), ] - converters.append( - get_parameter_converter( - f"{fast_llm_prefix}.output_scale", - f"{hf_prefix}.layer_scalar", - drop_on_export=drop_on_export, - ) - ) + converters.append(WeightConverter(f"{fast_llm_prefix}.output_scale", f"{hf_prefix}.layer_scalar")) return converters @@ -578,7 +558,6 @@ def get_converters( config: PatternBlockSequenceConfig, fast_llm_prefix: str, hf_prefix: str, - drop_on_export: bool = False, ) -> list[WeightConverter]: Assert.custom(isinstance, config, PatternBlockSequenceConfig) converters = [] @@ -588,7 +567,6 @@ def get_converters( block_config, f"{fast_llm_prefix}.{block_index}", f"{hf_prefix}.{block_index}", - drop_on_export=drop_on_export, ) return converters @@ -756,7 +734,9 @@ def 
_head_import(hf_dict: dict) -> dict: @classmethod def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), + *cls.embeddings_converter_class.emit_weight_converters( + config.embeddings, "embeddings", "model", root_config=config + ), *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), *cls.head_converter_class.get_converters(config, exported_config), ] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 29373ae74..546fc4894 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -56,82 +56,12 @@ def assert_no_peft(config: GPTBaseModelConfig) -> None: def effective_bias(layer_config: AffineLinearConfig, default: bool) -> bool: - """Resolve a layer's effective bias flag: explicit ``bias.enabled`` if set, else the parent default.""" - return default if layer_config.bias.enabled is None else layer_config.bias.enabled - - -# ============================================================ -# Legacy imperative weight-conversion helpers — kept for converters that haven't migrated to the -# declarative shape yet. Deleted in the final cleanup commit once every consumer is on -# ``_create_weight_converters``. -# ============================================================ - + """Resolve a layer's effective bias flag: explicit ``bias.enabled`` if set, else the parent default. -# Re-export under the legacy name so callers that import ``MLPLayer2Converter`` from this module keep -# working during migration. 
-MLPLayer2Converter = TransposeSplitWeightConverter - - -def get_parameter_converter( - fast_llm_name: str | tuple[str, ...], - hf_name: str | tuple[str, ...], - cls=WeightConverter, - config=None, - drop_on_export: bool = False, - drop_on_import: bool = False, -) -> WeightConverter: - from fast_llm.engine.checkpoint.external import IgnoreExportWeightConverter, IgnoreImportWeightConverter - - if isinstance(fast_llm_name, str): - fast_llm_name = (fast_llm_name,) - if isinstance(hf_name, str): - hf_name = (hf_name,) - if drop_on_export: - cls = IgnoreExportWeightConverter - if drop_on_import: - cls = IgnoreImportWeightConverter - return cls( - () if drop_on_import else fast_llm_name, - () if drop_on_export else hf_name, - config, - ) - - -def get_weight_and_bias_converters( - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - config=None, - drop_on_export: bool = False, - drop_on_import: bool = False, -) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - get_parameter_converter( - () if drop_on_import else tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - () if drop_on_export else tuple(f"{prefix}.weight" for prefix in hf_prefix), - cls, - config, - drop_on_export, - drop_on_import, - ) - ] - if use_bias: - converters.append( - get_parameter_converter( - () if drop_on_import else tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - () if drop_on_export else tuple(f"{prefix}.bias" for prefix in hf_prefix), - cls, - config, - drop_on_export, - drop_on_import, - ) - ) - return converters + Used by Apriel and Apriel2 config-side ``CustomConfigConverter`` export_fns (which need to translate + a per-layer override into an HF-side bias flag) and by their ``LinearWeightConverter.bias_fn`` lambdas. 
+ """ + return default if layer_config.bias.enabled is None else layer_config.bias.enabled # ============================================================ From 239906e4cb39a74f2be92f63346c9dd60d76ff61 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 May 2026 11:08:05 -0400 Subject: [PATCH 05/12] Address review: dead code, dead plumbing, framework polish MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Drop ``AprielBlockConverter.get_converters`` — calls a now-nonexistent ``.get_converters`` on the block-converter registry values and is unreachable in practice (dispatch goes through ``BlockSequenceWeightConverter(dispatch_registry=...)``). * Drop the unused ``block_converter_class`` ClassVar from Apriel/Mistral/Qwen2/Mixtral head converters — only MTP-Llama's head reads it (kept on ``LlamaHeadConverter``). * Drop the ``exported_config`` parameter throughout: no surviving ``get_converters`` override reads it, and the ``__init__`` ``_export_config(model.config)`` precompute it powered is gone. Tied-embedding handling lives on ``OutputProjectionWeightConverter``. * Fold ``_FixedBlockFanoutWeightConverter`` into ``BlockSequenceWeightConverter`` via a ``config_attr=""`` sentinel for "section IS the block sequence" — kills the cross-package private import from ``llama.py`` into ``multimodal/apriel2.py``. * ``LinearWeightConverter.bias_fn`` and ``OutputProjectionWeightConverter._emit`` use direct attribute access instead of ``getattr(..., default)`` — misuse now surfaces as ``AttributeError`` rather than silently falling back to ``False``. * Tighten ``BlockSequenceWeightConverter``'s assertion to XOR — passing both ``block_converter_class`` and ``dispatch_registry`` no longer silently ignores the former. * Extract ``_join_prefix(parent, own)`` helper for the empty-handling rule shared across Nested/BlockSequence/Dispatch/TypedDict ``_emit`` methods. 
* Apriel2 base + multimodal aggregators get a ``block_converter_class`` ClassVar (matches ``LlamaBaseModelConverter``) instead of hardcoding ``Apriel2BlockConverter`` inline. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 52 +++++++++++++----- fast_llm/engine/checkpoint/huggingface.py | 8 +-- fast_llm/models/gpt/conversion/apriel.py | 15 +----- fast_llm/models/gpt/conversion/apriel2.py | 5 +- fast_llm/models/gpt/conversion/gemma4.py | 4 +- fast_llm/models/gpt/conversion/llama.py | 53 ++++++------------- fast_llm/models/gpt/conversion/mistral.py | 2 +- fast_llm/models/gpt/conversion/mixtral.py | 2 +- fast_llm/models/gpt/conversion/mtp_llama.py | 1 - fast_llm/models/gpt/conversion/qwen2.py | 2 +- .../models/multimodal/conversion/apriel2.py | 15 +++--- .../models/multimodal/conversion/llava.py | 4 +- 12 files changed, 74 insertions(+), 89 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 5ac7004c9..71138db68 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -761,6 +761,18 @@ def _prepend_prefix(prefix: str, names: tuple[str, ...]) -> tuple[str, ...]: return tuple(f"{prefix}.{name}" for name in names) +def _join_prefix(parent: str, own: str) -> str: + """Join two dot-separated prefixes, tolerating either being empty. + + Structural primitives (Nested/BlockSequence/Dispatch/TypedDict) call this when building the + descended ``(fast_llm_prefix, hf_prefix)`` for the recursive call — either side can legitimately + be empty (the section converter sits at the HF root, or the declaration's own prefix is empty). + """ + if parent and own: + return f"{parent}.{own}" + return parent or own + + class WeightConverter: """Leaf weight-conversion declaration / emitted instance. 
@@ -950,7 +962,7 @@ def _emit( *, root_config: Config, ) -> list[WeightConverter]: - if getattr(root_config, "tied_embedding_weight", False): + if root_config.tied_embedding_weight: return [] return super()._emit(config, fast_llm_prefix, hf_prefix, root_config=root_config) @@ -992,8 +1004,8 @@ def _emit( sub_config = getattr(config, self._config_attr) return self._sub_converter_class.emit_weight_converters( sub_config, - f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix, - f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix), + _join_prefix(fast_llm_prefix, self._fast_llm_prefix), + _join_prefix(hf_prefix, self._hf_prefix), root_config=root_config, ) @@ -1007,6 +1019,15 @@ class BlockSequenceWeightConverter(WeightConverter): Handles both ``FixedBlockSequenceConfig`` (single repeated block) and ``PatternBlockSequenceConfig`` (per-position blocks indexed via ``decoder.expanded_pattern``). + + ``config_attr`` selects how the block sequence is reached from the parent config: + + * default (``None``) — read ``getattr(parent, fast_llm_prefix)``. + * explicit string — read ``getattr(parent, config_attr)``. + * empty string ``""`` — the *section* config is itself the block sequence (no parent attribute); + used when ``BlockSequenceWeightConverter`` is declared by a section converter whose + ``fast_llm_config_class`` is a ``FixedBlockSequenceConfig`` directly (e.g. + ``LlamaDecoderConverter`` plugged into the Pixtral vision encoder; Apriel2's vision encoder). """ def __init__( @@ -1018,8 +1039,11 @@ def __init__( config_attr: str | None = None, dispatch_registry: dict[type[Config], type["ConfigSectionConverter"]] | None = None, ): + # Exactly one of the two must be set: the single-class path uses ``block_converter_class``; + # the per-position-type-dispatch path uses ``dispatch_registry``. 
Passing both would silently + # ignore ``block_converter_class`` since ``_emit`` prefers the registry. Assert.custom( - lambda pair: pair[0] is not None or pair[1] is not None, + lambda pair: (pair[0] is None) != (pair[1] is None), (block_converter_class, dispatch_registry), ) super().__init__((), ()) @@ -1040,7 +1064,7 @@ def _emit( # Lazy import to keep external.py free of layers/ dependencies. from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - block_sequence = getattr(config, self._config_attr) + block_sequence = config if self._config_attr == "" else getattr(config, self._config_attr) if isinstance(block_sequence, FixedBlockSequenceConfig): per_position_blocks = [block_sequence.block] * block_sequence.num_blocks elif isinstance(block_sequence, PatternBlockSequenceConfig): @@ -1048,8 +1072,8 @@ def _emit( else: raise NotImplementedError(type(block_sequence).__name__) - fast_llm_root = f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix - hf_root = f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix) + fast_llm_root = _join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = _join_prefix(hf_prefix, self._hf_prefix) out: list[WeightConverter] = [] for index, block in enumerate(per_position_blocks): block_class = ( @@ -1102,8 +1126,8 @@ def _emit( sub_class = self._registry[type(sub_config)] return sub_class.emit_weight_converters( sub_config, - f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix, - f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix), + _join_prefix(fast_llm_prefix, self._fast_llm_prefix), + _join_prefix(hf_prefix, self._hf_prefix), root_config=root_config, ) @@ -1140,15 +1164,15 @@ def _emit( root_config: Config, ) -> list[WeightConverter]: attr_dict = getattr(config, self._config_attr) - fast_llm_root = 
f"{fast_llm_prefix}.{self._fast_llm_prefix}" if fast_llm_prefix else self._fast_llm_prefix - hf_root = f"{hf_prefix}.{self._hf_prefix}" if hf_prefix and self._hf_prefix else (hf_prefix or self._hf_prefix) + fast_llm_root = _join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = _join_prefix(hf_prefix, self._hf_prefix) out: list[WeightConverter] = [] for name, sub_config in attr_dict.items(): sub_class = self._registry[type(sub_config)] out += sub_class.emit_weight_converters( sub_config, - f"{fast_llm_root}.{name}" if fast_llm_root else name, - f"{hf_root}.{name}" if hf_root else name, + _join_prefix(fast_llm_root, name), + _join_prefix(hf_root, name), root_config=root_config, ) return out @@ -1175,7 +1199,7 @@ def __init__( hf_prefix: str | tuple[str, ...] | typing.Callable[[Config], str | tuple[str, ...]], *, transform: type[WeightConverter] = WeightConverter, - bias_fn: typing.Callable[[Config], bool] = lambda c: getattr(c, "add_linear_biases", False), + bias_fn: typing.Callable[[Config], bool] = lambda c: c.add_linear_biases, ): super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 9074e72fc..864a21300 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -36,7 +36,7 @@ def export_config(cls, config: BaseModelConfig) -> dict: @classmethod @abc.abstractmethod - def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]: + def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: pass @@ -44,10 +44,6 @@ class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, architecture: typing.ClassVar[str] base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]] - def __init__(self, model: "FastLLMModel"): - self._exported_config = self._export_config(model.config) - super().__init__(model) - 
@classmethod @abc.abstractmethod def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]: @@ -180,7 +176,7 @@ def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) def _create_weight_converters(self) -> list[WeightConverter]: - return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config) + return self.base_model_converter_class.get_converters(self._model.config.base_model) def _load_weights( self, config: CheckpointLoadConfig, device diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 6af2f151e..7f22db885 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -21,7 +21,6 @@ ) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig -from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat @@ -400,21 +399,9 @@ class AprielBlockConverter: GatedDeltaNetConfig: AprielGatedDeltaNetBlockConverter, } - @classmethod - def get_converters( - cls, - config: DecoderBlockConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return cls._converter_classes[type(config.mixer)].get_converters( - config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export - ) - class AprielHeadConverter(MistralHeadConverter): - block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter + pass class AprielBaseModelConverter(MistralBaseModelConverter): diff --git 
a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index 144acbc92..a6f51104c 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -693,6 +693,7 @@ class Apriel2BaseModelConverter(ConfigSectionConverter): fast_llm_config_class = GPTBaseModelConfig embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @classmethod @@ -726,12 +727,12 @@ def _validate_export(cls, config: GPTBaseModelConfig) -> None: def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), - "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", Apriel2BlockConverter), + "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", cls.block_converter_class), "head": NestedWeightConverter("head", "", cls.head_converter_class), } @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return cls.emit_weight_converters(config, "", "") diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 1e5e54468..752b0acdd 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -732,13 +732,13 @@ def _head_import(hf_dict: dict) -> dict: } @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return [ *cls.embeddings_converter_class.emit_weight_converters( config.embeddings, "embeddings", 
"model", root_config=config ), *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), - *cls.head_converter_class.get_converters(config, exported_config), + *cls.head_converter_class.get_converters(config), ] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 546fc4894..de71a13c6 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -362,35 +362,13 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # ``BlockSequenceWeightConverter`` works on parent-level (decoder lives one level up); here the - # section IS the block sequence, so emit a flat fan-out keyed by block index. - # Used only by Pixtral's vision encoder (the standard text formats inline the block dispatch at - # the base-model converter instead). - return {"blocks": _FixedBlockFanoutWeightConverter(cls.block_converter_class)} - - -class _FixedBlockFanoutWeightConverter(WeightConverter): - """Emit one set of block-sub-converter declarations per position of a ``FixedBlockSequenceConfig``. - - Lives here because ``LlamaDecoderConverter``'s section config *is* the block sequence — there is no - parent attribute to read via the generic :class:`BlockSequenceWeightConverter` shape. 
- """ - - def __init__(self, block_converter_class: type[ConfigSectionConverter]): - super().__init__((), ()) - self._block_converter_class = block_converter_class - - def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): - Assert.is_(type(config), FixedBlockSequenceConfig) - out: list[WeightConverter] = [] - for index in range(config.num_blocks): - out += self._block_converter_class.emit_weight_converters( - config.block, - f"{fast_llm_prefix}.{index}" if fast_llm_prefix else str(index), - f"{hf_prefix}.{index}" if hf_prefix else str(index), - root_config=root_config, - ) - return out + # The section config IS a ``FixedBlockSequenceConfig`` (no parent attribute holding it) — + # ``config_attr=""`` tells ``BlockSequenceWeightConverter`` to read the section config directly. + # Used by Pixtral's vision encoder and Apriel2's vision encoder; text formats inline the dispatch + # at the base-model converter instead. + return { + "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, config_attr=""), + } class LlamaEmbeddingsConverter(ConfigSectionConverter): @@ -459,12 +437,11 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: def get_converters( cls, config: GPTBaseModelConfig, - exported_config: dict, ) -> list[WeightConverter]: - """Aggregator-shape shim: non-migrated base-model converters pass the full - :class:`GPTBaseModelConfig` plus the exported HF dict. Translates to the declarative walker — - the tied-embedding handling now lives on :class:`OutputProjectionWeightConverter` and reads - ``root_config.tied_embedding_weight`` directly, so ``exported_config`` is unused. + """Aggregator entry-point: the base-model converter passes the full :class:`GPTBaseModelConfig` + so subclasses (e.g. MTP-Llama) can read ``config.decoder.last_block_config`` / + ``config.head.prediction_heads`` when extending the head's weights. 
Tied-embedding handling + lives on :class:`OutputProjectionWeightConverter` and reads ``root_config.tied_embedding_weight``. """ return cls.emit_weight_converters(config.head, "head", "", root_config=config) @@ -528,18 +505,18 @@ def _validate_export(cls, config: GPTBaseModelConfig) -> None: @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: # ``head`` is added at the aggregator level (in :meth:`get_converters`) because the head - # converter's signature takes the root config + the exported HF dict so subclasses can - # extend it (e.g. MTP-Llama fans out per-prediction-head blocks and norms). + # converter takes the full base-model config so subclasses can extend it (e.g. MTP-Llama + # fans out per-prediction-head blocks and norms). return { "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), "decoder": BlockSequenceWeightConverter("decoder", "model.layers", cls.block_converter_class), } @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return [ *cls.emit_weight_converters(config, "", ""), - *cls.head_converter_class.get_converters(config, exported_config), + *cls.head_converter_class.get_converters(config), ] diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 18251c760..ac42b4bf9 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -40,7 +40,7 @@ class MistralBlockConverter(LlamaBlockConverter): class MistralHeadConverter(LlamaHeadConverter): - block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter + pass class MistralBaseModelConverter(LlamaBaseModelConverter): diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 800d0973a..00b6f28ff 100644 --- 
a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -84,7 +84,7 @@ class MixtralBlockConverter(MistralBlockConverter): class MixtralHeadConverter(MistralHeadConverter): - block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter + pass class MixtralBaseModelConverter(MistralBaseModelConverter): diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 9c5c90c7e..d76e260da 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -46,7 +46,6 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: def get_converters( cls, config: GPTBaseModelConfig, - exported_config: dict, ) -> list[WeightConverter]: converters = list(cls.emit_weight_converters(config.head, "head", "", root_config=config)) # Append the MTP fan-out: one block + one norm per extra prediction head. ``block_converter_class`` diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index c9177ebea..d932c4f5e 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -102,7 +102,7 @@ class Qwen2BlockConverter(LlamaBlockConverter): class Qwen2HeadConverter(LlamaHeadConverter): - block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter + pass def _qwen2_mrope_guard_import(hf_dict: dict) -> dict: diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 7ff544ee2..f96f3f365 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -253,11 +253,11 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # The section config IS the FixedBlockSequenceConfig — use the LlamaDecoderConverter pattern of a - # 
custom fan-out primitive that reads ``config.block`` and ``config.num_blocks`` directly. - from fast_llm.models.gpt.conversion.llama import _FixedBlockFanoutWeightConverter - - return {"blocks": _FixedBlockFanoutWeightConverter(cls.block_converter_class)} + # The section config IS the FixedBlockSequenceConfig — ``config_attr=""`` makes + # BlockSequenceWeightConverter read the section config directly instead of via ``getattr``. + return { + "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, config_attr=""), + } class Apriel2EmbeddingsConverter(ConfigSectionConverter): @@ -431,6 +431,7 @@ class Apriel2MultimodalBaseModelConverter(ConfigSectionConverter, HuggingFaceBas text_base_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BaseModelConverter vision_model_converter_class: typing.ClassVar[type[Apriel2VisionModelConverter]] = Apriel2VisionModelConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter @classmethod @@ -486,12 +487,12 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { # ``embeddings`` flat-merges into ``model``; vision_encoder writes to its own absolute prefix. 
"embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), - "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", Apriel2BlockConverter), + "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", cls.block_converter_class), "head": NestedWeightConverter("head", "", cls.head_converter_class), } @classmethod - def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + def get_converters(cls, config: MultiModalBaseModelConfig) -> list[WeightConverter]: converters = list(cls.emit_weight_converters(config, "", "")) if config.vision_encoder is not None: # Vision encoder is optional — emit only when present. The Nested declaration in diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 274cea39a..915a93ea9 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -403,13 +403,13 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: } @classmethod - def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + def get_converters(cls, config: MultiModalBaseModelConfig) -> list[WeightConverter]: # ``head`` is added at the aggregator level because the LlavaHead's plain WeightConverter for # ``language_model.lm_head.weight`` doesn't fit a NestedWeightConverter under any HF prefix — # it lives at the HF root, not inside ``language_model.model``. 
return [ *cls.emit_weight_converters(config, "", ""), - *cls.language_model_converter_class.head_converter_class.get_converters(config, exported_config), + *cls.language_model_converter_class.head_converter_class.get_converters(config), ] From 998cbe83594adbb74cb5229b81c2e9dc995e35d1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 21 May 2026 15:21:42 -0400 Subject: [PATCH 06/12] Review-coarse cleanup pass on declarative-weight-converters PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Items 1-7 from the second /review-coarse pass: 1. Delete unused IgnoreImportWeightConverter / IgnoreExportWeightConverter (last callers were the removed drop_on_export plumbing). 2. Migrate Gemma4 to declarative weight converters — Gemma4Attention/MLP/ MoEMLP/HybridMoEMLP/Block/Decoder now inherit ConfigSectionConverter and define _create_weight_converters. The Gemma4-specific transforms (shared-K/V branching, mlp-type dispatch with divergent HF prefixes, conditional norm_2, two-level hybrid-MoE norm descent) live as small private WeightConverter subclasses next to the existing MoE layer converters. Config side stays imperative under CustomConfigConverter at the aggregator (Gemma4 sliding/full block divergence prevents a uniform per-block declarative shape); each helper carries a blanket IgnoredConfigConverter to silence the static walker. 3. Add optional=True to NestedWeightConverter and fold Apriel2 multimodal's vision_encoder back into _create_weight_converters (skip when None). 4. Fold Llava head into LlavaBaseModelConverter._create_weight_converters (NestedWeightConverter with empty hf_prefix; LlavaHead's leaf names are already absolute). 5. Move block_converter_class ClassVar from LlamaHeadConverter to its sole reader MTPLlamaHeadConverter. 6. Replace BlockSequenceWeightConverter's config_attr="" sentinel with an explicit read_self=True flag (2 callers updated). 7. 
Delete the four pass-only HeadConverter subclasses (Mistral, Mixtral, Qwen2, Apriel); the head_converter_class ClassVar inherits from LlamaBaseModelConverter, and LlavaHeadConverter rebases on LlamaHeadConverter directly. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 60 +-- fast_llm/models/gpt/conversion/apriel.py | 6 - fast_llm/models/gpt/conversion/gemma4.py | 448 +++++++++--------- fast_llm/models/gpt/conversion/llama.py | 6 +- fast_llm/models/gpt/conversion/mistral.py | 6 - fast_llm/models/gpt/conversion/mixtral.py | 6 - fast_llm/models/gpt/conversion/mtp_llama.py | 8 +- fast_llm/models/gpt/conversion/qwen2.py | 6 - .../models/multimodal/conversion/apriel2.py | 23 +- .../models/multimodal/conversion/llava.py | 22 +- 10 files changed, 265 insertions(+), 326 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 71138db68..398b47c8f 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -824,42 +824,6 @@ def _emit( ] -class IgnoreImportWeightConverter(WeightConverter): - def __post_init__(self): - Assert.eq(len(self.fast_llm_name), 0) - Assert.gt(len(self.export_name), 0) - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - raise RuntimeError( - f"IgnoreImportWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}" - ) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return () - - -class IgnoreExportWeightConverter(WeightConverter): - def __post_init__(self): - Assert.gt(len(self.fast_llm_name), 0) - Assert.eq(len(self.export_name), 0) - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
- ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return () - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - raise RuntimeError( - f"IgnoreExportWeightConverter should not be used for import: {self.fast_llm_name}, {self.export_name}" - ) - - class SplitWeightConverter(WeightConverter): def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] @@ -977,6 +941,9 @@ class NestedWeightConverter(WeightConverter): The separate ``config_attr`` covers cases like a block's single ``normalization`` config feeding two state-dict prefixes (``norm_1`` / ``norm_2``). + + ``optional=True`` skips the recursion when the resolved sub-config is ``None`` — for optional + architecture sections like Llava's ``vision_encoder``. """ def __init__( @@ -986,12 +953,14 @@ def __init__( sub_converter_class: type["ConfigSectionConverter"], *, config_attr: str | None = None, + optional: bool = False, ): super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix self._hf_prefix = hf_prefix self._sub_converter_class = sub_converter_class self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + self._optional = optional def _emit( self, @@ -1002,6 +971,8 @@ def _emit( root_config: Config, ) -> list[WeightConverter]: sub_config = getattr(config, self._config_attr) + if self._optional and sub_config is None: + return [] return self._sub_converter_class.emit_weight_converters( sub_config, _join_prefix(fast_llm_prefix, self._fast_llm_prefix), @@ -1020,12 +991,12 @@ class BlockSequenceWeightConverter(WeightConverter): Handles both ``FixedBlockSequenceConfig`` (single repeated block) and ``PatternBlockSequenceConfig`` (per-position blocks indexed via ``decoder.expanded_pattern``). 
- ``config_attr`` selects how the block sequence is reached from the parent config: + The block sequence is reached from the parent config in one of three ways: - * default (``None``) — read ``getattr(parent, fast_llm_prefix)``. - * explicit string — read ``getattr(parent, config_attr)``. - * empty string ``""`` — the *section* config is itself the block sequence (no parent attribute); - used when ``BlockSequenceWeightConverter`` is declared by a section converter whose + * default (``config_attr=None``) — read ``getattr(parent, fast_llm_prefix)``. + * explicit ``config_attr``-string — read ``getattr(parent, config_attr)``. + * ``read_self=True`` — the *section* config IS the block sequence (no parent attribute); used + when ``BlockSequenceWeightConverter`` is declared by a section converter whose ``fast_llm_config_class`` is a ``FixedBlockSequenceConfig`` directly (e.g. ``LlamaDecoderConverter`` plugged into the Pixtral vision encoder; Apriel2's vision encoder). """ @@ -1037,6 +1008,7 @@ def __init__( block_converter_class: type["ConfigSectionConverter"] | None = None, *, config_attr: str | None = None, + read_self: bool = False, dispatch_registry: dict[type[Config], type["ConfigSectionConverter"]] | None = None, ): # Exactly one of the two must be set: the single-class path uses ``block_converter_class``; @@ -1046,10 +1018,14 @@ def __init__( lambda pair: (pair[0] is None) != (pair[1] is None), (block_converter_class, dispatch_registry), ) + # ``config_attr`` and ``read_self`` are mutually exclusive — one tells us how to descend to the + # block sequence; the other says we're already there. 
+ Assert.custom(lambda pair: not (pair[0] is not None and pair[1]), (config_attr, read_self)) super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix self._hf_prefix = hf_prefix self._block_converter_class = block_converter_class + self._read_self = read_self self._config_attr = config_attr if config_attr is not None else fast_llm_prefix self._dispatch_registry = dispatch_registry @@ -1064,7 +1040,7 @@ def _emit( # Lazy import to keep external.py free of layers/ dependencies. from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - block_sequence = config if self._config_attr == "" else getattr(config, self._config_attr) + block_sequence = config if self._read_self else getattr(config, self._config_attr) if isinstance(block_sequence, FixedBlockSequenceConfig): per_position_blocks = [block_sequence.block] * block_sequence.num_blocks elif isinstance(block_sequence, PatternBlockSequenceConfig): diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 7f22db885..749b14e1c 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -28,7 +28,6 @@ from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) from fast_llm.utils import Assert, safe_merge_dicts @@ -400,10 +399,6 @@ class AprielBlockConverter: } -class AprielHeadConverter(MistralHeadConverter): - pass - - class AprielBaseModelConverter(MistralBaseModelConverter): """Section converter for the Apriel hybrid-SSM base model. @@ -412,7 +407,6 @@ class AprielBaseModelConverter(MistralBaseModelConverter): HF keys flat-merge into the parent HF root. 
""" - head_converter_class: typing.ClassVar[type[AprielHeadConverter]] = AprielHeadConverter # Distinct from the parent's ``block_converter_class`` (a single ``ConfigSectionConverter``); this # one holds the per-mixer-type dispatch registries that :class:`ListDispatchConfigConverter` and # the weight-side loop below consume. diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 752b0acdd..dc78cbf59 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -1,19 +1,25 @@ """Gemma4 checkpoint format converter.""" +import functools import typing +import torch + from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantExportConfigConverter, CustomConfigConverter, IgnoredConfigConverter, - KeyValueWeightConverter, + LinearWeightConverter, + NestedWeightConverter, RenameConfigConverter, SplitWeightConverter, TransposeSplitWeightConverter, WeightConverter, + _join_prefix, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.functional.config import ActivationType @@ -37,41 +43,6 @@ from fast_llm.models.gpt.model import GPTModel from fast_llm.utils import Assert, safe_merge_dicts - -def _linear_converters( - fast_llm_prefix: str, - hf_prefix: str | tuple[str, ...], - use_bias: bool, - transform: type[WeightConverter] = WeightConverter, - config=None, -) -> list[WeightConverter]: - """Local helper: build ``.weight`` and (conditional) ``.bias`` converters for one linear layer. - - Gemma4's helper classes don't inherit ``ConfigSectionConverter``, so the - :class:`LinearWeightConverter` declarative primitive doesn't apply directly — this is the - smallest helper that covers the Gemma4-specific imperative ``get_converters`` shape. 
``config`` - is forwarded to the transform constructor when the transform captures it - (e.g. :class:`KeyValueWeightConverter`). - """ - hf_names = (hf_prefix,) if isinstance(hf_prefix, str) else tuple(hf_prefix) - converters = [ - transform( - f"{fast_llm_prefix}.weight", - tuple(f"{name}.weight" for name in hf_names), - config, - ) - ] - if use_bias: - converters.append( - transform( - f"{fast_llm_prefix}.bias", - tuple(f"{name}.bias" for name in hf_names), - config, - ) - ) - return converters - - _SLIDING_ATTENTION = "sliding_attention" _FULL_ATTENTION = "full_attention" @@ -108,7 +79,110 @@ def import_weight(self, weight): return (w.permute(0, 2, 1).reshape(-1, w.shape[1]).contiguous(),) -class Gemma4AttentionConverter: +class _Gemma4BlockMLPWeightConverter(WeightConverter): + """Dispatch ``block.mlp`` to dense :class:`Gemma4MLPConverter` (under ``mlp.<...>``) or hybrid + :class:`Gemma4HybridMoEMLPConverter` (flat-merged into the block's HF root) based on the runtime + type of ``config.mlp``. The two targets diverge on HF prefix, which the generic + :class:`DispatchWeightConverter` doesn't accommodate. + """ + + def __init__(self) -> None: + super().__init__((), ()) + + def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): + fast_llm_mlp = _join_prefix(fast_llm_prefix, "mlp") + if isinstance(config.mlp, HybridMoEMLPConfig): + return Gemma4HybridMoEMLPConverter.emit_weight_converters( + config.mlp, fast_llm_mlp, hf_prefix, root_config=root_config + ) + return Gemma4MLPConverter.emit_weight_converters( + config.mlp, fast_llm_mlp, _join_prefix(hf_prefix, "mlp"), root_config=root_config + ) + + +class _Gemma4BlockNorm2WeightConverter(WeightConverter): + """Dense Gemma4 blocks store the pre-MLP norm at ``norm_2`` (drawn from the block's main + ``normalization`` config). MoE blocks suppress this — the routed/dense branches inside the + hybrid MoE own their own pre/post norms. 
+ """ + + def __init__(self) -> None: + super().__init__((), ()) + + def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): + if isinstance(config.mlp, HybridMoEMLPConfig): + return [] + return LlamaNormalizationConverter.emit_weight_converters( + config.normalization, + _join_prefix(fast_llm_prefix, "norm_2"), + _join_prefix(hf_prefix, "pre_feedforward_layernorm"), + root_config=root_config, + ) + + +class _Gemma4HybridMoENormWeightConverter(WeightConverter): + """Emit a normalization config nested two attributes deep inside a hybrid MoE block + (e.g. ``config.dense.pre_norm``, ``config.routed.post_norm``). The single-level + :class:`NestedWeightConverter` can't express the chained descent — :class:`MLPConfig` and + :class:`MoEMLPConfig` each carry their own pre/post norms, but those branches live one level + below the hybrid MoE section root. Gemma4's hybrid MoE always sets these norms. + """ + + def __init__(self, branch: str, norm_attr: str, hf_name: str) -> None: + super().__init__((), ()) + self._branch = branch + self._norm_attr = norm_attr + self._hf_name = hf_name + + def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): + norm_config = getattr(getattr(config, self._branch), self._norm_attr) + return LlamaNormalizationConverter.emit_weight_converters( + norm_config, + _join_prefix(fast_llm_prefix, f"{self._branch}.{self._norm_attr}"), + _join_prefix(hf_prefix, self._hf_name), + root_config=root_config, + ) + + +class _Gemma4SharedKeyValueWeightConverter(WeightConverter): + """``shared_key_value=True`` Gemma4 attention: Fast-LLM's ``key_value`` is a single K-shaped + tensor (V is reused at runtime) and maps to a single HF ``k_proj`` — plain rename. Falls back to + :class:`KeyValueWeightConverter` (chunk/cat across K and V) when not shared. 
+ """ + + _config: AttentionConfig + + def export_weight(self, weight): + if self._config.shared_key_value: + return weight + (key_value,) = weight + return key_value[:].chunk(2) + + def import_weight(self, weight): + if self._config.shared_key_value: + return weight + key, value = weight + return (torch.cat([key[:], value[:]]),) + + +class Gemma4AttentionConverter(ConfigSectionConverter): + """Gemma4's attention helper: ``import_config`` / ``export_config`` take non-standard arguments + (sliding/full discrimination, twin block exports) and are invoked imperatively from + :class:`Gemma4BlockConverter`. Only the weight side fits the standard declarative shape — biases + are always disabled, query/key norms are emitted only when present, and the K/V layout collapses + to a single ``k_proj`` when ``shared_key_value`` is set. + + The config side is owned by :class:`Gemma4BaseModelConverter`'s ``decoder`` :class:`CustomConfigConverter` + (with ``fast_llm_recurses=True``); the blanket-claim below silences the static architecture-coverage + walker — Gemma4's sliding/full divergence prevents a uniform declarative shape per single block. 
+ """ + + fast_llm_config_class = AttentionConfig + + @classmethod + def _create_config_converters(cls) -> dict: + return {"_blanket": IgnoredConfigConverter(())} + @classmethod def import_config(cls, config: dict, is_sliding: bool) -> dict: eps = config["rms_norm_eps"] @@ -187,57 +261,31 @@ def export_config(cls, sliding_config: AttentionConfig, full_config: AttentionCo } @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - ) -> list[WeightConverter]: - if config.shared_key_value: - # K=V: single k_proj reused as value; no v_proj in HF - kv_converters = _linear_converters( - f"{fast_llm_prefix}.key_value", - f"{hf_prefix}.k_proj", - False, - ) - else: - kv_converters = _linear_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - False, - KeyValueWeightConverter, - config, - ) - converters = [ - *_linear_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - False, - ), - *kv_converters, - *_linear_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - False, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "query": LinearWeightConverter("query", "q_proj", bias_fn=lambda c: False), + "key_value": LinearWeightConverter( + "key_value", + lambda c: "k_proj" if c.shared_key_value else ("k_proj", "v_proj"), + transform=_Gemma4SharedKeyValueWeightConverter, + bias_fn=lambda c: False, ), - ] - if config.query_norm is not None: - converters += LlamaNormalizationConverter.emit_weight_converters( - config.query_norm, - f"{fast_llm_prefix}.query_norm", - f"{hf_prefix}.q_norm", - ) - if config.key_norm is not None: - converters += LlamaNormalizationConverter.emit_weight_converters( - config.key_norm, - f"{fast_llm_prefix}.key_norm", - f"{hf_prefix}.k_norm", - ) - # value_norm is FixedRMSNorm — no learnable weight to convert - return converters + "dense": LinearWeightConverter("dense", "o_proj", 
bias_fn=lambda c: False), + # ``value_norm`` is :class:`FixedRMSNormConfig` (no learnable weight) — not declared. + "query_norm": NestedWeightConverter("query_norm", "q_norm", LlamaNormalizationConverter, optional=True), + "key_norm": NestedWeightConverter("key_norm", "k_norm", LlamaNormalizationConverter, optional=True), + } + + +class Gemma4MLPConverter(ConfigSectionConverter): + fast_llm_config_class = MLPConfig + @classmethod + def _create_config_converters(cls) -> dict: + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. + return {"_blanket": IgnoredConfigConverter(())} -class Gemma4MLPConverter: @classmethod def import_config(cls, config: dict) -> dict: return { @@ -260,29 +308,25 @@ def export_config(cls, config: MLPConfig) -> dict: } @classmethod - def get_converters( - cls, - config: MLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - ) -> list[WeightConverter]: - return [ - *_linear_converters( - f"{fast_llm_prefix}.layer_1", - (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), - False, - SplitWeightConverter, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter( + "layer_1", ("gate_proj", "up_proj"), transform=SplitWeightConverter, bias_fn=lambda c: False ), - *_linear_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - False, - TransposeSplitWeightConverter, + "layer_2": LinearWeightConverter( + "layer_2", "down_proj", transform=TransposeSplitWeightConverter, bias_fn=lambda c: False ), - ] + } -class Gemma4MoEMLPConverter: +class Gemma4MoEMLPConverter(ConfigSectionConverter): + fast_llm_config_class = MoEMLPConfig + + @classmethod + def _create_config_converters(cls) -> dict: + return {"_blanket": IgnoredConfigConverter(())} + @classmethod def import_config(cls, config: dict) -> dict: eps = config["rms_norm_eps"] @@ -329,28 +373,25 @@ def export_config(cls, config: MoEMLPConfig, hidden_size: 
int) -> dict: } @classmethod - def get_converters( - cls, - config: MoEMLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - ) -> list[WeightConverter]: - converters = [ - *_linear_converters( - f"{fast_llm_prefix}.router", - f"{hf_prefix}.router.proj", - False, - ), - WeightConverter(f"{fast_llm_prefix}.router_scale", f"{hf_prefix}.router.scale"), - WeightConverter(f"{fast_llm_prefix}.router_per_expert_scale", f"{hf_prefix}.router.per_expert_scale"), - Gemma4MoELayer1Converter(f"{fast_llm_prefix}.layer_1.weight", f"{hf_prefix}.experts.gate_up_proj", config), - Gemma4MoELayer2Converter(f"{fast_llm_prefix}.layer_2.weight", f"{hf_prefix}.experts.down_proj", config), - ] - # router.norm is FixedRMSNorm — no learnable weight to convert. - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``router.norm`` is :class:`FixedRMSNormConfig` (no learnable weight) — not declared. + return { + "router": LinearWeightConverter("router", "router.proj", bias_fn=lambda c: False), + "router_scale": WeightConverter("router_scale", "router.scale"), + "router_per_expert_scale": WeightConverter("router_per_expert_scale", "router.per_expert_scale"), + "layer_1": Gemma4MoELayer1Converter("layer_1.weight", "experts.gate_up_proj"), + "layer_2": Gemma4MoELayer2Converter("layer_2.weight", "experts.down_proj"), + } + +class Gemma4HybridMoEMLPConverter(ConfigSectionConverter): + fast_llm_config_class = HybridMoEMLPConfig + + @classmethod + def _create_config_converters(cls) -> dict: + return {"_blanket": IgnoredConfigConverter(())} -class Gemma4HybridMoEMLPConverter: @classmethod def import_config(cls, config: dict) -> dict: eps = config["rms_norm_eps"] @@ -376,47 +417,32 @@ def export_config(cls, config: HybridMoEMLPConfig, hidden_size: int) -> dict: ) @classmethod - def get_converters( - cls, - config: HybridMoEMLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - ) -> list[WeightConverter]: - return [ - 
*Gemma4MLPConverter.get_converters( - config.dense, - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.mlp", - ), - *Gemma4MoEMLPConverter.get_converters( - config.routed, - f"{fast_llm_prefix}.routed", - hf_prefix, - ), - *LlamaNormalizationConverter.emit_weight_converters( - config.dense.pre_norm, - f"{fast_llm_prefix}.dense.pre_norm", - f"{hf_prefix}.pre_feedforward_layernorm", - ), - *LlamaNormalizationConverter.emit_weight_converters( - config.dense.post_norm, - f"{fast_llm_prefix}.dense.post_norm", - f"{hf_prefix}.post_feedforward_layernorm_1", + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "dense": NestedWeightConverter("dense", "mlp", Gemma4MLPConverter), + # Routed branch lives at the block's HF root (sibling of ``mlp.<...>``). + "routed": NestedWeightConverter("routed", "", Gemma4MoEMLPConverter), + "dense_pre_norm": _Gemma4HybridMoENormWeightConverter("dense", "pre_norm", "pre_feedforward_layernorm"), + "dense_post_norm": _Gemma4HybridMoENormWeightConverter( + "dense", "post_norm", "post_feedforward_layernorm_1" ), - *LlamaNormalizationConverter.emit_weight_converters( - config.routed.pre_norm, - f"{fast_llm_prefix}.routed.pre_norm", - f"{hf_prefix}.pre_feedforward_layernorm_2", + "routed_pre_norm": _Gemma4HybridMoENormWeightConverter( + "routed", "pre_norm", "pre_feedforward_layernorm_2" ), - *LlamaNormalizationConverter.emit_weight_converters( - config.routed.post_norm, - f"{fast_llm_prefix}.routed.post_norm", - f"{hf_prefix}.post_feedforward_layernorm_2", + "routed_post_norm": _Gemma4HybridMoENormWeightConverter( + "routed", "post_norm", "post_feedforward_layernorm_2" ), - ] + } + + +class Gemma4BlockConverter(ConfigSectionConverter): + fast_llm_config_class = DecoderBlockConfig + @classmethod + def _create_config_converters(cls) -> dict: + return {"_blanket": IgnoredConfigConverter(())} -class Gemma4BlockConverter: @classmethod def import_config(cls, config: dict, is_sliding: bool) -> dict: def 
make_norm(): @@ -465,61 +491,42 @@ def export_config( return out @classmethod - def get_converters( - cls, - config: DecoderBlockConfig, - fast_llm_prefix: str, - hf_prefix: str, - ) -> list[WeightConverter]: - is_moe = isinstance(config.mlp, HybridMoEMLPConfig) - converters = [ - *Gemma4AttentionConverter.get_converters( - config.mixer, - f"{fast_llm_prefix}.mixer", - f"{hf_prefix}.self_attn", - ), - ] - if is_moe: - converters += Gemma4HybridMoEMLPConverter.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - hf_prefix, - ) - else: - converters += Gemma4MLPConverter.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - f"{hf_prefix}.mlp", - ) - converters += LlamaNormalizationConverter.emit_weight_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.pre_feedforward_layernorm", - ) - converters += [ - *LlamaNormalizationConverter.emit_weight_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.input_layernorm", + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": NestedWeightConverter("mixer", "self_attn", Gemma4AttentionConverter), + "mlp": _Gemma4BlockMLPWeightConverter(), + "norm_1": NestedWeightConverter( + "norm_1", "input_layernorm", LlamaNormalizationConverter, config_attr="normalization" ), - *LlamaNormalizationConverter.emit_weight_converters( - config.post_mixer_normalization, - f"{fast_llm_prefix}.post_mixer_norm", - f"{hf_prefix}.post_attention_layernorm", + "norm_2": _Gemma4BlockNorm2WeightConverter(), + "post_mixer_norm": NestedWeightConverter( + "post_mixer_norm", + "post_attention_layernorm", + LlamaNormalizationConverter, + config_attr="post_mixer_normalization", ), - *LlamaNormalizationConverter.emit_weight_converters( - config.post_mlp_normalization, - f"{fast_llm_prefix}.post_mlp_norm", - f"{hf_prefix}.post_feedforward_layernorm", + "post_mlp_norm": NestedWeightConverter( + "post_mlp_norm", + 
"post_feedforward_layernorm", + LlamaNormalizationConverter, + config_attr="post_mlp_normalization", ), - ] - converters.append(WeightConverter(f"{fast_llm_prefix}.output_scale", f"{hf_prefix}.layer_scalar")) - return converters + # HF stores ``layer_scalar`` as a non-trained buffer; Fast-LLM mirrors it with a frozen + # ``output_scale`` (``lr_scale=0``). + "output_scale": WeightConverter("output_scale", "layer_scalar"), + } -class Gemma4DecoderConverter: +class Gemma4DecoderConverter(ConfigSectionConverter): + fast_llm_config_class = PatternBlockSequenceConfig + block_converter_class: typing.ClassVar[type[Gemma4BlockConverter]] = Gemma4BlockConverter + @classmethod + def _create_config_converters(cls) -> dict: + return {"_blanket": IgnoredConfigConverter(())} + @classmethod def import_config(cls, config: dict) -> dict: layer_types = config["layer_types"] @@ -553,22 +560,11 @@ def export_config(cls, config: PatternBlockSequenceConfig, hidden_size: int) -> ) @classmethod - def get_converters( - cls, - config: PatternBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - ) -> list[WeightConverter]: - Assert.custom(isinstance, config, PatternBlockSequenceConfig) - converters = [] - for block_index in range(config.num_blocks): - block_config = config.blocks[config.expanded_pattern[block_index]] - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, read_self=True), + } class Gemma4EmbeddingsConverter(LlamaEmbeddingsConverter): @@ -737,7 +733,9 @@ def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: *cls.embeddings_converter_class.emit_weight_converters( config.embeddings, "embeddings", "model", root_config=config ), - 
*cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), + *cls.decoder_converter_class.emit_weight_converters( + config.decoder, "decoder", "model.layers", root_config=config + ), *cls.head_converter_class.get_converters(config), ] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index de71a13c6..0e92bceca 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -363,11 +363,11 @@ def _create_config_converters(cls) -> dict: @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: # The section config IS a ``FixedBlockSequenceConfig`` (no parent attribute holding it) — - # ``config_attr=""`` tells ``BlockSequenceWeightConverter`` to read the section config directly. + # ``read_self=True`` tells ``BlockSequenceWeightConverter`` to read the section config directly. # Used by Pixtral's vision encoder and Apriel2's vision encoder; text formats inline the dispatch # at the base-model converter instead. return { - "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, config_attr=""), + "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, read_self=True), } @@ -405,8 +405,6 @@ class LlamaHeadConverter(ConfigSectionConverter): fast_llm_config_class = LanguageModelHeadConfig normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter - # Used by MTP-Llama subclass to emit per-prediction-head block weight converters; Llama itself doesn't read it. 
- block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod def _create_config_converters(cls) -> dict: diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index ac42b4bf9..a9b77fde5 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -7,7 +7,6 @@ LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, ) @@ -39,13 +38,8 @@ class MistralBlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter -class MistralHeadConverter(LlamaHeadConverter): - pass - - class MistralBaseModelConverter(LlamaBaseModelConverter): block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter - head_converter_class: typing.ClassVar[type[MistralHeadConverter]] = MistralHeadConverter class MistralHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 00b6f28ff..4d377dcf4 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -17,7 +17,6 @@ from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) from fast_llm.utils import Assert @@ -83,13 +82,8 @@ class MixtralBlockConverter(MistralBlockConverter): mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter -class MixtralHeadConverter(MistralHeadConverter): - pass - - class MixtralBaseModelConverter(MistralBaseModelConverter): block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter - head_converter_class: typing.ClassVar[type[MixtralHeadConverter]] = MixtralHeadConverter class 
MixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index d76e260da..4513b5bd4 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -14,6 +14,7 @@ from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, + LlamaBlockConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, ) @@ -21,6 +22,9 @@ class MTPLlamaHeadConverter(LlamaHeadConverter): + # The MTP block shape matches the main decoder block, so we plug ``LlamaBlockConverter`` in directly. + block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + @classmethod def _create_config_converters(cls) -> dict: return { @@ -48,9 +52,7 @@ def get_converters( config: GPTBaseModelConfig, ) -> list[WeightConverter]: converters = list(cls.emit_weight_converters(config.head, "head", "", root_config=config)) - # Append the MTP fan-out: one block + one norm per extra prediction head. ``block_converter_class`` - # comes from the parent ``LlamaHeadConverter`` ClassVar — the MTP block shape matches the main - # decoder block. + # Append the MTP fan-out: one block + one norm per extra prediction head. 
for prediction_distance in range(2, config.head.prediction_heads + 1): converters += cls.block_converter_class.emit_weight_converters( config.decoder.last_block_config, diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index d932c4f5e..b5f4bf0dd 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -18,7 +18,6 @@ LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, ) @@ -101,10 +100,6 @@ class Qwen2BlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter -class Qwen2HeadConverter(LlamaHeadConverter): - pass - - def _qwen2_mrope_guard_import(hf_dict: dict) -> dict: if hf_dict.get("use_mrope") is True: raise NotImplementedError("MRoPE (use_mrope=True) is not supported by the Qwen2 converter") @@ -113,7 +108,6 @@ def _qwen2_mrope_guard_import(hf_dict: dict) -> dict: class Qwen2BaseModelConverter(LlamaBaseModelConverter): block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter - head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter @classmethod def _create_config_converters(cls) -> dict: diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index f96f3f365..d3f597a65 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -253,10 +253,10 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # The section config IS the FixedBlockSequenceConfig — ``config_attr=""`` makes + # The section config IS the FixedBlockSequenceConfig — ``read_self=True`` makes # BlockSequenceWeightConverter read the section config directly instead of via ``getattr``. 
return { - "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, config_attr=""), + "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, read_self=True), } @@ -485,6 +485,11 @@ def _vision_import(hf_dict: dict) -> dict: @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { + # Vision encoder is optional — ``optional=True`` skips the recursion when + # ``config.vision_encoder is None`` (text-only checkpoints). + "vision_encoder": NestedWeightConverter( + "vision_encoder", "", cls.vision_model_converter_class, optional=True + ), # ``embeddings`` flat-merges into ``model``; vision_encoder writes to its own absolute prefix. "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", cls.block_converter_class), @@ -493,19 +498,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: @classmethod def get_converters(cls, config: MultiModalBaseModelConfig) -> list[WeightConverter]: - converters = list(cls.emit_weight_converters(config, "", "")) - if config.vision_encoder is not None: - # Vision encoder is optional — emit only when present. The Nested declaration in - # :meth:`_create_weight_converters` couldn't conditionally fire on a None attribute. 
- converters = ( - list( - cls.vision_model_converter_class.emit_weight_converters( - config.vision_encoder, "vision_encoder", "", root_config=config - ) - ) - + converters - ) - return converters + return cls.emit_weight_converters(config, "", "") class Apriel2HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 915a93ea9..1df64465c 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -30,9 +30,10 @@ LlamaAttentionConverter, LlamaBlockConverter, LlamaDecoderConverter, + LlamaHeadConverter, LlamaNormalizationConverter, ) -from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter, MistralHeadConverter, MistralMLPConverter +from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter, MistralMLPConverter from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel @@ -164,10 +165,8 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: bias_fn=lambda c: False, ), # ``PixtralEmbeddingsConverter``'s section config IS the ``PatchEmbeddingsConfig`` (carries the - # normalization sub-config directly), so the nested ``LlamaNormalizationConverter`` reads from - # ``config_attr="normalization"`` — but the original code passed the *parent* config in. Mirror - # that by reading ``self`` (config_attr=""): the norm converter only needs ``.weight`` and the - # parent already exposes that field directly. + # ``normalization`` sub-config directly), so the nested ``LlamaNormalizationConverter`` reads + # ``getattr(section_config, "normalization")``. 
"normalization": NestedWeightConverter( "normalization", "ln_pre", cls.normalization_converter_class, config_attr="normalization" ), @@ -291,7 +290,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: } -class LlavaHeadConverter(MistralHeadConverter): +class LlavaHeadConverter(LlamaHeadConverter): # Llava always writes ``lm_head.weight`` on export (never dropped, even when ``tied_embedding_weight=True``); # the parent's :class:`OutputProjectionWeightConverter` would also drop on export, so we replace it with a # plain rename. When the HF state-dict lacks ``lm_head.weight`` (tied case), the handler's per-converter @@ -400,17 +399,14 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "decoder": BlockSequenceWeightConverter( "decoder", "language_model.model.layers", text_base_cls.block_converter_class ), + # ``LlavaHeadConverter``'s leaf converters use absolute HF paths (``language_model.lm_head.weight``, + # ``language_model.model.norm``), so an empty ``hf_prefix`` lets them land verbatim. + "head": NestedWeightConverter("head", "", text_base_cls.head_converter_class), } @classmethod def get_converters(cls, config: MultiModalBaseModelConfig) -> list[WeightConverter]: - # ``head`` is added at the aggregator level because the LlavaHead's plain WeightConverter for - # ``language_model.lm_head.weight`` doesn't fit a NestedWeightConverter under any HF prefix — - # it lives at the HF root, not inside ``language_model.model``. 
- return [ - *cls.emit_weight_converters(config, "", ""), - *cls.language_model_converter_class.head_converter_class.get_converters(config), - ] + return cls.emit_weight_converters(config, "", "") class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): From 91af72dfc6cfc22e67aee6a01fb0855c234de4f0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 25 May 2026 12:15:21 -0400 Subject: [PATCH 07/12] Second /review-coarse pass: structural cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Items 1-4 from the third /review-coarse pass: 1. Split ``BlockSequenceWeightConverter`` into three flat primitives — single-class (``BlockSequenceWeightConverter``), per-position dispatch (``DispatchBlockSequenceWeightConverter``), and section-IS-the-block-sequence (``SelfBlockSequenceWeightConverter``). Drops the dual XOR ``Assert.custom`` switchboard. Shared list-materialization extracted to ``_expand_block_sequence``. 2. Generalize two framework primitives so two of Gemma4's four private one-offs fold in: * ``DispatchWeightConverter`` gains ``hf_prefix_overrides`` for per-branch HF paths (Gemma4's block.mlp dispatch where dense lands under ``mlp.<...>`` and hybrid MoE flat-merges into the block root). * ``NestedWeightConverter.config_attr`` accepts tuple/dotted paths for chained ``getattr`` (Gemma4's hybrid-MoE inner norms via ``("dense", "pre_norm")``). Rename ``_join_prefix`` and ``_prepend_prefix`` to drop the underscore — now public utilities used by Gemma4's remaining two one-offs. 3. Lift the one-line ``cls.emit_weight_converters(config, "", "")`` passthrough into ``HuggingFaceBaseModelConverter.get_converters`` as a concrete default. Apriel2 (text), Apriel2 multimodal, and Llava lose their overrides. Apriel2BaseModelConverter now multi-inherits ``HuggingFaceBaseModelConverter`` so it picks up the default. 
Llama, Gemma4, MTP-Llama keep their overrides — they splice ``head_converter_class.get_converters(config)`` separately because the head needs the full ``GPTBaseModelConfig`` (MTP-Llama reads ``config.decoder.last_block_config`` for per-prediction-head fan-out). 4. ``AprielBlockConverter`` docstring: ``get_converters`` was removed in the prior cleanup pass; update the docstring to describe the class as a registry holder consumed by ``ListDispatchConfigConverter`` (config side) and ``DispatchBlockSequenceWeightConverter`` (weight side). Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/checkpoint/external.py | 212 ++++++++++++------ fast_llm/engine/checkpoint/huggingface.py | 9 +- fast_llm/models/gpt/conversion/apriel.py | 18 +- fast_llm/models/gpt/conversion/apriel2.py | 8 +- fast_llm/models/gpt/conversion/gemma4.py | 95 +++----- fast_llm/models/gpt/conversion/llama.py | 10 +- .../models/multimodal/conversion/apriel2.py | 11 +- .../models/multimodal/conversion/llava.py | 4 - 8 files changed, 202 insertions(+), 165 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 398b47c8f..27a2ceea5 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -754,14 +754,14 @@ def check_architecture_coverage(cls, config: Config) -> None: ) -def _prepend_prefix(prefix: str, names: tuple[str, ...]) -> tuple[str, ...]: +def prepend_prefix(prefix: str, names: tuple[str, ...]) -> tuple[str, ...]: """Prepend ``prefix`` to each name. Empty ``prefix`` is a no-op; empty ``names`` (drop side) stays empty.""" if not prefix: return names return tuple(f"{prefix}.{name}" for name in names) -def _join_prefix(parent: str, own: str) -> str: +def join_prefix(parent: str, own: str) -> str: """Join two dot-separated prefixes, tolerating either being empty. 
Structural primitives (Nested/BlockSequence/Dispatch/TypedDict) call this when building the @@ -817,8 +817,8 @@ def _emit( """ return [ type(self)( - _prepend_prefix(fast_llm_prefix, self.fast_llm_name), - _prepend_prefix(hf_prefix, self.export_name), + prepend_prefix(fast_llm_prefix, self.fast_llm_name), + prepend_prefix(hf_prefix, self.export_name), config, ) ] @@ -934,13 +934,16 @@ def _emit( class NestedWeightConverter(WeightConverter): """Recurse into a sub-section's weight declarations. - The sub-section's config is read from ``getattr(config, config_attr)`` (defaults to ``fast_llm_prefix`` - when the state-dict prefix and the parent's attribute name agree). The walker descends into - ``sub_converter_class._create_weight_converters()`` with extended prefixes. Mirrors - :class:`NestedConfigConverter` on the config side. + The sub-section's config is read by chained ``getattr`` from ``config`` via ``config_attr``: - The separate ``config_attr`` covers cases like a block's single ``normalization`` config feeding two - state-dict prefixes (``norm_1`` / ``norm_2``). + * ``None`` (default) — single attribute named after ``fast_llm_prefix`` (e.g. ``getattr(config, "mixer")``). + * single string — single attribute (covers a block's ``normalization`` config feeding two state-dict + prefixes ``norm_1`` / ``norm_2``). + * tuple of strings — chained descent (e.g. ``("dense", "pre_norm")`` ↦ ``config.dense.pre_norm`` for + Gemma4's hybrid-MoE inner norms). + + The walker descends into ``sub_converter_class._create_weight_converters()`` with extended prefixes. + Mirrors :class:`NestedConfigConverter` on the config side. ``optional=True`` skips the recursion when the resolved sub-config is ``None`` — for optional architecture sections like Llava's ``vision_encoder``. @@ -952,14 +955,19 @@ def __init__( hf_prefix: str, sub_converter_class: type["ConfigSectionConverter"], *, - config_attr: str | None = None, + config_attr: str | tuple[str, ...] 
| None = None, optional: bool = False, ): super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix self._hf_prefix = hf_prefix self._sub_converter_class = sub_converter_class - self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + if config_attr is None: + self._config_attrs: tuple[str, ...] = (fast_llm_prefix,) + elif isinstance(config_attr, str): + self._config_attrs = (config_attr,) + else: + self._config_attrs = config_attr self._optional = optional def _emit( @@ -970,64 +978,57 @@ def _emit( *, root_config: Config, ) -> list[WeightConverter]: - sub_config = getattr(config, self._config_attr) + sub_config: typing.Any = config + for attr in self._config_attrs: + sub_config = getattr(sub_config, attr) if self._optional and sub_config is None: return [] return self._sub_converter_class.emit_weight_converters( sub_config, - _join_prefix(fast_llm_prefix, self._fast_llm_prefix), - _join_prefix(hf_prefix, self._hf_prefix), + join_prefix(fast_llm_prefix, self._fast_llm_prefix), + join_prefix(hf_prefix, self._hf_prefix), root_config=root_config, ) -class BlockSequenceWeightConverter(WeightConverter): - """Fan out a per-block sub-section across every position in a block sequence. +def _expand_block_sequence(block_sequence: Config) -> list[Config]: + """Materialize a ``Fixed``/``PatternBlockSequenceConfig`` into a per-position list of block configs. + + ``FixedBlockSequenceConfig``: single repeated block. ``PatternBlockSequenceConfig``: per-position + blocks indexed via ``block_sequence.expanded_pattern``. + """ + # Lazy import to keep external.py free of layers/ dependencies.
+ from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - The sub-section's converter class is resolved per-position from ``block_converter_class``: by default, - the same class for every position; when ``dispatch_registry`` is provided, the per-position config is - matched against the registry keys (Apriel's hybrid-block dispatch — different mixer types per layer). + if isinstance(block_sequence, FixedBlockSequenceConfig): + return [block_sequence.block] * block_sequence.num_blocks + if isinstance(block_sequence, PatternBlockSequenceConfig): + return [block_sequence.blocks[name] for name in block_sequence.expanded_pattern] + raise NotImplementedError(type(block_sequence).__name__) - Handles both ``FixedBlockSequenceConfig`` (single repeated block) and ``PatternBlockSequenceConfig`` - (per-position blocks indexed via ``decoder.expanded_pattern``). - The block sequence is reached from the parent config in one of three ways: +class BlockSequenceWeightConverter(WeightConverter): + """Fan a single block converter across every position in a block sequence reached via attribute access. - * default (``config_attr=None``) — read ``getattr(parent, fast_llm_prefix)``. - * explicit ``config_attr``-string — read ``getattr(parent, config_attr)``. - * ``read_self=True`` — the *section* config IS the block sequence (no parent attribute); used - when ``BlockSequenceWeightConverter`` is declared by a section converter whose - ``fast_llm_config_class`` is a ``FixedBlockSequenceConfig`` directly (e.g. - ``LlamaDecoderConverter`` plugged into the Pixtral vision encoder; Apriel2's vision encoder). + Reads the block sequence from ``getattr(parent, config_attr)`` (defaults to ``fast_llm_prefix``). + For per-position type dispatch (different mixer types per layer) use + :class:`DispatchBlockSequenceWeightConverter` instead; when the section config IS the block sequence + use :class:`SelfBlockSequenceWeightConverter`. 
""" def __init__( self, fast_llm_prefix: str, hf_prefix: str, - block_converter_class: type["ConfigSectionConverter"] | None = None, + block_converter_class: type["ConfigSectionConverter"], *, config_attr: str | None = None, - read_self: bool = False, - dispatch_registry: dict[type[Config], type["ConfigSectionConverter"]] | None = None, ): - # Exactly one of the two must be set: the single-class path uses ``block_converter_class``; - # the per-position-type-dispatch path uses ``dispatch_registry``. Passing both would silently - # ignore ``block_converter_class`` since ``_emit`` prefers the registry. - Assert.custom( - lambda pair: (pair[0] is None) != (pair[1] is None), - (block_converter_class, dispatch_registry), - ) - # ``config_attr`` and ``read_self`` are mutually exclusive — one tells us how to descend to the - # block sequence; the other says we're already there. - Assert.custom(lambda pair: not (pair[0] is not None and pair[1]), (config_attr, read_self)) super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix self._hf_prefix = hf_prefix self._block_converter_class = block_converter_class - self._read_self = read_self self._config_attr = config_attr if config_attr is not None else fast_llm_prefix - self._dispatch_registry = dispatch_registry def _emit( self, @@ -1037,26 +1038,53 @@ def _emit( *, root_config: Config, ) -> list[WeightConverter]: - # Lazy import to keep external.py free of layers/ dependencies. 
- from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig - - block_sequence = config if self._read_self else getattr(config, self._config_attr) - if isinstance(block_sequence, FixedBlockSequenceConfig): - per_position_blocks = [block_sequence.block] * block_sequence.num_blocks - elif isinstance(block_sequence, PatternBlockSequenceConfig): - per_position_blocks = [block_sequence.blocks[name] for name in block_sequence.expanded_pattern] - else: - raise NotImplementedError(type(block_sequence).__name__) - - fast_llm_root = _join_prefix(fast_llm_prefix, self._fast_llm_prefix) - hf_root = _join_prefix(hf_prefix, self._hf_prefix) + fast_llm_root = join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = join_prefix(hf_prefix, self._hf_prefix) out: list[WeightConverter] = [] - for index, block in enumerate(per_position_blocks): - block_class = ( - self._dispatch_registry[type(block.mixer)] - if self._dispatch_registry is not None - else self._block_converter_class + for index, block in enumerate(_expand_block_sequence(getattr(config, self._config_attr))): + out += self._block_converter_class.emit_weight_converters( + block, + f"{fast_llm_root}.{index}", + f"{hf_root}.{index}", + root_config=root_config, ) + return out + + +class DispatchBlockSequenceWeightConverter(WeightConverter): + """Fan a per-position-dispatched block converter across every position in a block sequence. + + Each position's block config is matched against ``dispatch_registry`` keys by its mixer type + (``type(block.mixer)``) — Apriel's hybrid-block dispatch. 
+ """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + dispatch_registry: dict[type[Config], type["ConfigSectionConverter"]], + *, + config_attr: str | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._dispatch_registry = dispatch_registry + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + fast_llm_root = join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = join_prefix(hf_prefix, self._hf_prefix) + out: list[WeightConverter] = [] + for index, block in enumerate(_expand_block_sequence(getattr(config, self._config_attr))): + block_class = self._dispatch_registry[type(block.mixer)] out += block_class.emit_weight_converters( block, f"{fast_llm_root}.{index}", @@ -1066,6 +1094,37 @@ def _emit( return out +class SelfBlockSequenceWeightConverter(WeightConverter): + """Fan a single block converter across the section config when *the section IS the block sequence*. + + Used when the declaring section's ``fast_llm_config_class`` is itself a ``FixedBlockSequenceConfig`` + or ``PatternBlockSequenceConfig`` (e.g. ``LlamaDecoderConverter`` plugged into the Pixtral vision + encoder; Apriel2 and Gemma4's decoders). Weights land directly under the section's outer prefixes. 
+ """ + + def __init__(self, block_converter_class: type["ConfigSectionConverter"]): + super().__init__((), ()) + self._block_converter_class = block_converter_class + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + out: list[WeightConverter] = [] + for index, block in enumerate(_expand_block_sequence(config)): + out += self._block_converter_class.emit_weight_converters( + block, + f"{fast_llm_prefix}.{index}", + f"{hf_prefix}.{index}", + root_config=root_config, + ) + return out + + class DispatchWeightConverter(WeightConverter): """Dispatch a single sub-section converter based on the live config's runtime type. @@ -1074,6 +1133,9 @@ class DispatchWeightConverter(WeightConverter): Mirrors :class:`DispatchConfigConverter` on the config side. Used when a single attribute holds one of several alternative configs (e.g. Apriel2's block ``mixer`` may be attention/mamba/gdn/kda/stochastic; its ``normalization`` may be RMS/Layer/None). + + ``hf_prefix_overrides`` lets individual branches replace the shared ``hf_prefix`` (e.g. Gemma4's + hybrid MoE flat-merges into the block root while dense MLP lands under ``mlp.<...>``). 
""" def __init__( @@ -1083,12 +1145,14 @@ def __init__( registry: dict[type[Config], type["ConfigSectionConverter"]], *, config_attr: str | None = None, + hf_prefix_overrides: dict[type[Config], str] | None = None, ): super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix self._hf_prefix = hf_prefix self._registry = registry self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + self._hf_prefix_overrides = hf_prefix_overrides or {} def _emit( self, @@ -1099,11 +1163,13 @@ def _emit( root_config: Config, ) -> list[WeightConverter]: sub_config = getattr(config, self._config_attr) - sub_class = self._registry[type(sub_config)] + sub_type = type(sub_config) + sub_class = self._registry[sub_type] + own_hf_prefix = self._hf_prefix_overrides.get(sub_type, self._hf_prefix) return sub_class.emit_weight_converters( sub_config, - _join_prefix(fast_llm_prefix, self._fast_llm_prefix), - _join_prefix(hf_prefix, self._hf_prefix), + join_prefix(fast_llm_prefix, self._fast_llm_prefix), + join_prefix(hf_prefix, own_hf_prefix), root_config=root_config, ) @@ -1140,15 +1206,15 @@ def _emit( root_config: Config, ) -> list[WeightConverter]: attr_dict = getattr(config, self._config_attr) - fast_llm_root = _join_prefix(fast_llm_prefix, self._fast_llm_prefix) - hf_root = _join_prefix(hf_prefix, self._hf_prefix) + fast_llm_root = join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = join_prefix(hf_prefix, self._hf_prefix) out: list[WeightConverter] = [] for name, sub_config in attr_dict.items(): sub_class = self._registry[type(sub_config)] out += sub_class.emit_weight_converters( sub_config, - _join_prefix(fast_llm_root, name), - _join_prefix(hf_root, name), + join_prefix(fast_llm_root, name), + join_prefix(hf_root, name), root_config=root_config, ) return out @@ -1195,12 +1261,12 @@ def _emit( ) -> list[WeightConverter]: resolved = self._hf_prefix(config) if callable(self._hf_prefix) else self._hf_prefix hf_prefixes: tuple[str, ...] 
= (resolved,) if isinstance(resolved, str) else tuple(resolved) - weight_fast_llm = _prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.weight",)) - weight_hf = _prepend_prefix(hf_prefix, tuple(f"{p}.weight" for p in hf_prefixes)) + weight_fast_llm = prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.weight",)) + weight_hf = prepend_prefix(hf_prefix, tuple(f"{p}.weight" for p in hf_prefixes)) emitted: list[WeightConverter] = [self._transform(weight_fast_llm, weight_hf, config)] if self._bias_fn(config): - bias_fast_llm = _prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.bias",)) - bias_hf = _prepend_prefix(hf_prefix, tuple(f"{p}.bias" for p in hf_prefixes)) + bias_fast_llm = prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.bias",)) + bias_hf = prepend_prefix(hf_prefix, tuple(f"{p}.bias" for p in hf_prefixes)) emitted.append(self._transform(bias_fast_llm, bias_hf, config)) return emitted diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 864a21300..59b10e389 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -35,9 +35,14 @@ def export_config(cls, config: BaseModelConfig) -> dict: pass @classmethod - @abc.abstractmethod def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: - pass + """Default: walk the section's weight declarations from the root. + + Subclasses with constructs that don't fit the standard declaration walk override — e.g. + :class:`LlamaBaseModelConverter` splices the head's weights separately so MTP-Llama's + per-prediction-head fan-out has access to the full base-model config. 
+ """ + return cls.emit_weight_converters(config, "", "") # type: ignore[attr-defined] class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 749b14e1c..8ec56349c 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -7,11 +7,11 @@ from fast_llm.config import Config, get_nested_dict_value from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( - BlockSequenceWeightConverter, ConfigConverter, ConfigSectionConverter, CustomConfigConverter, DefaultConfigConverter, + DispatchBlockSequenceWeightConverter, IgnoredConfigConverter, LinearWeightConverter, RenameConfigConverter, @@ -376,14 +376,14 @@ def _consumed_hf_paths(self) -> frozenset[tuple[str, ...]]: class AprielBlockConverter: - """Per-mixer-type Apriel block-converter registries plus a weight-side dispatch helper. + """Registry holder for Apriel's per-mixer-type block converters. ``layout_names`` maps the mixer-config classes that participate in Apriel's - ``hybrid_block_layout`` discriminator to their string layout names. ``_converter_classes`` - maps every mixer-config class whose weights can appear in an Apriel checkpoint to its block - converter — a superset of ``layout_names`` keys that adds ``KimiDeltaAttentionConfig`` for - weight-only coverage. :meth:`get_converters` picks the right block converter from - ``_converter_classes`` by the mixer's runtime type. + ``hybrid_block_layout`` discriminator to their string layout names — consumed by + :class:`ListDispatchConfigConverter` on the config side. 
``_converter_classes`` maps every + mixer-config class whose weights can appear in an Apriel checkpoint to its block converter + (a superset of ``layout_names`` keys that adds ``KimiDeltaAttentionConfig`` for weight-only + coverage) — consumed by ``DispatchBlockSequenceWeightConverter`` on the weight side. """ layout_names: typing.ClassVar[dict[type[Config], str]] = { @@ -432,10 +432,10 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: # the right block converter from the dispatcher's registry based on the mixer's runtime type. return { **super()._create_weight_converters(), - "decoder": BlockSequenceWeightConverter( + "decoder": DispatchBlockSequenceWeightConverter( "decoder", "model.layers", - dispatch_registry=cls.block_dispatcher_class._converter_classes, + cls.block_dispatcher_class._converter_classes, ), } diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index a6f51104c..f546eac8e 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -28,7 +28,7 @@ TypedDictWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig @@ -689,7 +689,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: } -class Apriel2BaseModelConverter(ConfigSectionConverter): +class Apriel2BaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): fast_llm_config_class = GPTBaseModelConfig embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter @@ -731,10 +731,6 @@ def 
_create_weight_converters(cls) -> dict[str, WeightConverter]: "head": NestedWeightConverter("head", "", cls.head_converter_class), } - @classmethod - def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: - return cls.emit_weight_converters(config, "", "") - class Apriel2HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: GPTModel diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index dc78cbf59..c511d624e 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -8,18 +8,19 @@ from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( - BlockSequenceWeightConverter, ConfigSectionConverter, ConstantExportConfigConverter, CustomConfigConverter, + DispatchWeightConverter, IgnoredConfigConverter, LinearWeightConverter, NestedWeightConverter, RenameConfigConverter, + SelfBlockSequenceWeightConverter, SplitWeightConverter, TransposeSplitWeightConverter, WeightConverter, - _join_prefix, + join_prefix, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.functional.config import ActivationType @@ -79,31 +80,11 @@ def import_weight(self, weight): return (w.permute(0, 2, 1).reshape(-1, w.shape[1]).contiguous(),) -class _Gemma4BlockMLPWeightConverter(WeightConverter): - """Dispatch ``block.mlp`` to dense :class:`Gemma4MLPConverter` (under ``mlp.<...>``) or hybrid - :class:`Gemma4HybridMoEMLPConverter` (flat-merged into the block's HF root) based on the runtime - type of ``config.mlp``. The two targets diverge on HF prefix, which the generic - :class:`DispatchWeightConverter` doesn't accommodate. 
- """ - - def __init__(self) -> None: - super().__init__((), ()) - - def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): - fast_llm_mlp = _join_prefix(fast_llm_prefix, "mlp") - if isinstance(config.mlp, HybridMoEMLPConfig): - return Gemma4HybridMoEMLPConverter.emit_weight_converters( - config.mlp, fast_llm_mlp, hf_prefix, root_config=root_config - ) - return Gemma4MLPConverter.emit_weight_converters( - config.mlp, fast_llm_mlp, _join_prefix(hf_prefix, "mlp"), root_config=root_config - ) - - class _Gemma4BlockNorm2WeightConverter(WeightConverter): """Dense Gemma4 blocks store the pre-MLP norm at ``norm_2`` (drawn from the block's main ``normalization`` config). MoE blocks suppress this — the routed/dense branches inside the - hybrid MoE own their own pre/post norms. + hybrid MoE own their own pre/post norms. The conditional emit reads a sibling attribute + (``config.mlp``) and can't be expressed via :class:`NestedWeightConverter.optional`. """ def __init__(self) -> None: @@ -114,32 +95,8 @@ def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): return [] return LlamaNormalizationConverter.emit_weight_converters( config.normalization, - _join_prefix(fast_llm_prefix, "norm_2"), - _join_prefix(hf_prefix, "pre_feedforward_layernorm"), - root_config=root_config, - ) - - -class _Gemma4HybridMoENormWeightConverter(WeightConverter): - """Emit a normalization config nested two attributes deep inside a hybrid MoE block - (e.g. ``config.dense.pre_norm``, ``config.routed.post_norm``). The single-level - :class:`NestedWeightConverter` can't express the chained descent — :class:`MLPConfig` and - :class:`MoEMLPConfig` each carry their own pre/post norms, but those branches live one level - below the hybrid MoE section root. Gemma4's hybrid MoE always sets these norms. 
- """ - - def __init__(self, branch: str, norm_attr: str, hf_name: str) -> None: - super().__init__((), ()) - self._branch = branch - self._norm_attr = norm_attr - self._hf_name = hf_name - - def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): - norm_config = getattr(getattr(config, self._branch), self._norm_attr) - return LlamaNormalizationConverter.emit_weight_converters( - norm_config, - _join_prefix(fast_llm_prefix, f"{self._branch}.{self._norm_attr}"), - _join_prefix(hf_prefix, self._hf_name), + join_prefix(fast_llm_prefix, "norm_2"), + join_prefix(hf_prefix, "pre_feedforward_layernorm"), root_config=root_config, ) @@ -423,15 +380,29 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "dense": NestedWeightConverter("dense", "mlp", Gemma4MLPConverter), # Routed branch lives at the block's HF root (sibling of ``mlp.<...>``). "routed": NestedWeightConverter("routed", "", Gemma4MoEMLPConverter), - "dense_pre_norm": _Gemma4HybridMoENormWeightConverter("dense", "pre_norm", "pre_feedforward_layernorm"), - "dense_post_norm": _Gemma4HybridMoENormWeightConverter( - "dense", "post_norm", "post_feedforward_layernorm_1" + "dense_pre_norm": NestedWeightConverter( + "dense.pre_norm", + "pre_feedforward_layernorm", + LlamaNormalizationConverter, + config_attr=("dense", "pre_norm"), + ), + "dense_post_norm": NestedWeightConverter( + "dense.post_norm", + "post_feedforward_layernorm_1", + LlamaNormalizationConverter, + config_attr=("dense", "post_norm"), ), - "routed_pre_norm": _Gemma4HybridMoENormWeightConverter( - "routed", "pre_norm", "pre_feedforward_layernorm_2" + "routed_pre_norm": NestedWeightConverter( + "routed.pre_norm", + "pre_feedforward_layernorm_2", + LlamaNormalizationConverter, + config_attr=("routed", "pre_norm"), ), - "routed_post_norm": _Gemma4HybridMoENormWeightConverter( - "routed", "post_norm", "post_feedforward_layernorm_2" + "routed_post_norm": NestedWeightConverter( + "routed.post_norm", + "post_feedforward_layernorm_2", 
+ LlamaNormalizationConverter, + config_attr=("routed", "post_norm"), ), } @@ -495,7 +466,13 @@ def export_config( def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { "mixer": NestedWeightConverter("mixer", "self_attn", Gemma4AttentionConverter), - "mlp": _Gemma4BlockMLPWeightConverter(), + # Dense MLP lands under ``mlp.<...>``; hybrid MoE flat-merges into the block's HF root. + "mlp": DispatchWeightConverter( + "mlp", + "mlp", + registry={MLPConfig: Gemma4MLPConverter, HybridMoEMLPConfig: Gemma4HybridMoEMLPConverter}, + hf_prefix_overrides={HybridMoEMLPConfig: ""}, + ), "norm_1": NestedWeightConverter( "norm_1", "input_layernorm", LlamaNormalizationConverter, config_attr="normalization" ), @@ -563,7 +540,7 @@ def export_config(cls, config: PatternBlockSequenceConfig, hidden_size: int) -> @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { - "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, read_self=True), + "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), } diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 0e92bceca..4d1deb6b3 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -20,6 +20,7 @@ NestedWeightConverter, OutputProjectionWeightConverter, RenameConfigConverter, + SelfBlockSequenceWeightConverter, SplitWeightConverter, TransposeSplitWeightConverter, WeightConverter, @@ -362,12 +363,11 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # The section config IS a ``FixedBlockSequenceConfig`` (no parent attribute holding it) — - # ``read_self=True`` tells ``BlockSequenceWeightConverter`` to read the section config directly. 
- # Used by Pixtral's vision encoder and Apriel2's vision encoder; text formats inline the dispatch - # at the base-model converter instead. + # The section config IS a ``FixedBlockSequenceConfig`` (no parent attribute holding it). Used by + # Pixtral's vision encoder and Apriel2's vision encoder; text formats inline the dispatch at the + # base-model converter instead. return { - "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, read_self=True), + "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), } diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index d3f597a65..6cbefa0c2 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -18,6 +18,7 @@ OutputProjectionWeightConverter, PatchEmbeddingWeightConverter, RenameConfigConverter, + SelfBlockSequenceWeightConverter, TransposeSplitWeightConverter, WeightConverter, ) @@ -253,10 +254,10 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # The section config IS the FixedBlockSequenceConfig — ``read_self=True`` makes - # BlockSequenceWeightConverter read the section config directly instead of via ``getattr``. + # The section config IS the FixedBlockSequenceConfig — SelfBlockSequenceWeightConverter reads + # the section config directly instead of via ``getattr``. 
return { - "blocks": BlockSequenceWeightConverter("", "", cls.block_converter_class, read_self=True), + "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), } @@ -496,10 +497,6 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "head": NestedWeightConverter("head", "", cls.head_converter_class), } - @classmethod - def get_converters(cls, config: MultiModalBaseModelConfig) -> list[WeightConverter]: - return cls.emit_weight_converters(config, "", "") - class Apriel2HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: MultiModalModel diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 1df64465c..ffd84173d 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -404,10 +404,6 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "head": NestedWeightConverter("head", "", text_base_cls.head_converter_class), } - @classmethod - def get_converters(cls, config: MultiModalBaseModelConfig) -> list[WeightConverter]: - return cls.emit_weight_converters(config, "", "") - class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): _model: MultiModalModel From 968ec6954744616fd212a924abbc0e2f21bab96f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 25 May 2026 14:37:29 -0400 Subject: [PATCH 08/12] Add weight_only flag for config-side-empty section converters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Six Gemma4 section converters owned their config side only nominally — the actual conversion lives on the aggregator's CustomConfigConverter (fast_llm_recurses=True). Each declared an empty-path IgnoredConfigConverter blanket-claim solely to silence the static architecture-coverage walker. 
Replace the boilerplate with an explicit weight_only ClassVar on ConfigSectionConverter that short-circuits both _create_config_converters (empty default) and check_architecture_coverage. --- fast_llm/engine/checkpoint/external.py | 14 +++++++++ fast_llm/models/gpt/conversion/gemma4.py | 40 ++++++++---------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 27a2ceea5..7706f19be 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -559,6 +559,12 @@ class ConfigSectionConverter(abc.ABC): fast_llm_config_class: typing.ClassVar[type[Config]] hf_type_name: typing.ClassVar[str | None] = None + weight_only: typing.ClassVar[bool] = False + """When ``True``, the section contributes weight conversions only — its config side is owned by an + ancestor (typically via :class:`CustomConfigConverter` with ``fast_llm_recurses=True``). + :meth:`_create_config_converters` defaults to no declarations and :meth:`check_architecture_coverage` + short-circuits, so the section does not need to claim its own architecture leaves. + """ @classmethod @functools.cache @@ -568,7 +574,10 @@ def _create_config_converters(cls) -> dict[str, ConfigConverter]: Cached per class — declarations are immutable and depend only on ``cls``. Subclasses must build and return a *fresh* dict (idiomatically ``{**super()._create_config_converters(), ...}``); mutating the returned dict in place would corrupt the parent's cache entry for every subsequent caller. + Weight-only sections (``weight_only=True``) inherit the empty default. """ + if cls.weight_only: + return {} raise NotImplementedError @classmethod @@ -718,7 +727,12 @@ def check_architecture_coverage(cls, config: Config) -> None: Invoked from a test fixture (``tests/models/test_converters.py``) — not from the production export/import paths. 
Architecture coverage is a structural invariant of the converter declarations, so it only needs to hold once per (converter, config-class) pair, not on every save. + + Short-circuits for weight-only sections (``weight_only=True``): the config side is owned by an + ancestor declaration, so the architecture leaves are claimed there. """ + if cls.weight_only: + return Assert.is_(type(config), cls.fast_llm_config_class) declarations = cls._create_config_converters() explicit_paths: set[tuple[str, ...]] = set() diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index c511d624e..47fe8ed66 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -130,15 +130,12 @@ class Gemma4AttentionConverter(ConfigSectionConverter): to a single ``k_proj`` when ``shared_key_value`` is set. The config side is owned by :class:`Gemma4BaseModelConverter`'s ``decoder`` :class:`CustomConfigConverter` - (with ``fast_llm_recurses=True``); the blanket-claim below silences the static architecture-coverage - walker — Gemma4's sliding/full divergence prevents a uniform declarative shape per single block. + (with ``fast_llm_recurses=True``); the ``weight_only`` flag below signals that — Gemma4's sliding/full + divergence prevents a uniform declarative shape per single block. """ fast_llm_config_class = AttentionConfig - - @classmethod - def _create_config_converters(cls) -> dict: - return {"_blanket": IgnoredConfigConverter(())} + weight_only: typing.ClassVar[bool] = True @classmethod def import_config(cls, config: dict, is_sliding: bool) -> dict: @@ -236,12 +233,9 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: class Gemma4MLPConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. 
fast_llm_config_class = MLPConfig - - @classmethod - def _create_config_converters(cls) -> dict: - # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. - return {"_blanket": IgnoredConfigConverter(())} + weight_only: typing.ClassVar[bool] = True @classmethod def import_config(cls, config: dict) -> dict: @@ -278,11 +272,9 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: class Gemma4MoEMLPConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. fast_llm_config_class = MoEMLPConfig - - @classmethod - def _create_config_converters(cls) -> dict: - return {"_blanket": IgnoredConfigConverter(())} + weight_only: typing.ClassVar[bool] = True @classmethod def import_config(cls, config: dict) -> dict: @@ -343,11 +335,9 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: class Gemma4HybridMoEMLPConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. fast_llm_config_class = HybridMoEMLPConfig - - @classmethod - def _create_config_converters(cls) -> dict: - return {"_blanket": IgnoredConfigConverter(())} + weight_only: typing.ClassVar[bool] = True @classmethod def import_config(cls, config: dict) -> dict: @@ -408,11 +398,9 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: class Gemma4BlockConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. 
fast_llm_config_class = DecoderBlockConfig - - @classmethod - def _create_config_converters(cls) -> dict: - return {"_blanket": IgnoredConfigConverter(())} + weight_only: typing.ClassVar[bool] = True @classmethod def import_config(cls, config: dict, is_sliding: bool) -> dict: @@ -496,14 +484,12 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: class Gemma4DecoderConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. fast_llm_config_class = PatternBlockSequenceConfig + weight_only: typing.ClassVar[bool] = True block_converter_class: typing.ClassVar[type[Gemma4BlockConverter]] = Gemma4BlockConverter - @classmethod - def _create_config_converters(cls) -> dict: - return {"_blanket": IgnoredConfigConverter(())} - @classmethod def import_config(cls, config: dict) -> dict: layer_types = config["layer_types"] From 6a267707d8d1c80816c9d94f504bd202f0483de8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 25 May 2026 17:29:43 -0400 Subject: [PATCH 09/12] Review-fine cleanup pass MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Delete dead helpers: get_apriel2_decoder_converter and two unreferenced _get_weight_converters classmethods (the latter also broken — called get_converters with a now-removed two-arg signature). - Inline the now-unconditional ConfigSectionConverter coverage call in HuggingfaceStateDictCheckpointHandler._check_hf_coverage; every concrete base_model_converter_class is one. - Drop the _effective_bias import alias in apriel2 (no name conflict). - Trim docstrings that referenced the previous (removed) implementation: OutputProjectionWeightConverter / LinearWeightConverter / LlavaHeadConverter. - Accept bias_fn=True/False bool literals on LinearWeightConverter; replaces ~10 `lambda c: False` callsites including all `no_bias = lambda c: False` named bindings. 
- Hoist orphan trailing comments on Apriel2Fixed/PatternDecoderConverter into class docstrings. --- fast_llm/engine/checkpoint/external.py | 18 +++---- fast_llm/engine/checkpoint/huggingface.py | 6 +-- fast_llm/models/gpt/conversion/apriel.py | 19 +++---- fast_llm/models/gpt/conversion/apriel2.py | 53 +++++++------------ fast_llm/models/gpt/conversion/gemma4.py | 12 ++--- fast_llm/models/gpt/conversion/qwen2.py | 2 +- .../models/multimodal/conversion/apriel2.py | 6 +-- .../models/multimodal/conversion/llava.py | 10 ++-- 8 files changed, 47 insertions(+), 79 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 7706f19be..b26e99d2d 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -928,8 +928,7 @@ class OutputProjectionWeightConverter(WeightConverter): """Marker for the LM-head output projection (typically ``head.output_weights`` ↔ ``lm_head.weight``). When the root config has ``tied_embedding_weight=True``, the walker drops this declaration entirely — - HF stores tied embeddings as just ``embed_tokens.weight`` with no separate ``lm_head.weight``. Replaces - the per-call ``drop_on_export=exported_config["tie_word_embeddings"]`` plumbing. + HF stores tied embeddings as just ``embed_tokens.weight`` with no separate ``lm_head.weight``. """ def _emit( @@ -1237,16 +1236,14 @@ def _emit( class LinearWeightConverter(WeightConverter): """Bundle a linear layer's ``.weight`` and (conditionally) ``.bias`` declarations into one entry. - Bias presence is resolved at emission time from the live section config: ``bias_fn(config)`` returns - a bool. The default reads ``config.add_linear_biases`` — the shared flag every Llama-style attention/MLP - section carries. Sections with per-layer overrides (e.g. Apriel Mamba's ``dt_layer`` / ``convolution_layer``) - pass a lambda that resolves the override. 
+ Bias presence is resolved at emission time from the live section config: ``bias_fn`` is either a bool + literal (always / never) or a callable returning a bool. The default reads ``config.add_linear_biases`` — + the shared flag every Llama-style attention/MLP section carries. Sections with per-layer overrides (e.g. + Apriel Mamba's ``dt_layer`` / ``convolution_layer``) pass a lambda that resolves the override. ``transform`` selects the leaf class for both weight and bias: :class:`WeightConverter` for plain rename (the default), :class:`SplitWeightConverter` for fused → split, :class:`KeyValueWeightConverter` for fused KV → separate K/V, :class:`TransposeSplitWeightConverter` for MLP down-projection. - - Replaces the imperative ``get_weight_and_bias_converters`` / ``effective_bias`` helpers. """ def __init__( @@ -1255,7 +1252,7 @@ def __init__( hf_prefix: str | tuple[str, ...] | typing.Callable[[Config], str | tuple[str, ...]], *, transform: type[WeightConverter] = WeightConverter, - bias_fn: typing.Callable[[Config], bool] = lambda c: c.add_linear_biases, + bias_fn: bool | typing.Callable[[Config], bool] = lambda c: c.add_linear_biases, ): super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix @@ -1278,7 +1275,8 @@ def _emit( weight_fast_llm = prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.weight",)) weight_hf = prepend_prefix(hf_prefix, tuple(f"{p}.weight" for p in hf_prefixes)) emitted: list[WeightConverter] = [self._transform(weight_fast_llm, weight_hf, config)] - if self._bias_fn(config): + has_bias = self._bias_fn if isinstance(self._bias_fn, bool) else self._bias_fn(config) + if has_bias: bias_fast_llm = prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.bias",)) bias_hf = prepend_prefix(hf_prefix, tuple(f"{p}.bias" for p in hf_prefixes)) emitted.append(self._transform(bias_fast_llm, bias_hf, config)) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 59b10e389..5b4400a4b 
100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -10,7 +10,6 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig from fast_llm.engine.checkpoint.external import ( - ConfigSectionConverter, ExternalStateDictCheckpointHandler, WeightConverter, logger, @@ -165,13 +164,10 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None: """Run the HF-side coverage check at the import boundary. - Skips silently when the format's base-model converter isn't a :class:`ConfigSectionConverter` - (e.g. multimodal aggregators built on top of imperative ``HuggingFaceBaseModelConverter``). Subclasses that override :meth:`_import_config` should call this explicitly to keep the check active. """ - if issubclass(cls.base_model_converter_class, ConfigSectionConverter): - cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST) + cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST) @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index 8ec56349c..ea8e2cc9e 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -189,17 +189,15 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # GDN has no linear biases — explicit ``bias_fn=lambda c: False`` since GatedDeltaNetConfig has - # no ``add_linear_biases`` field for the default to read. - no_bias = lambda c: False + # GDN has no ``add_linear_biases`` field; pass ``bias_fn=False`` so the default isn't consulted. 
return { - "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=no_bias), - "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=no_bias), - "convolution": LinearWeightConverter("convolution", "convolution", bias_fn=no_bias), - "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=no_bias), + "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=False), + "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=False), + "convolution": LinearWeightConverter("convolution", "convolution", bias_fn=False), + "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=False), "A_log": WeightConverter("A_log", "A_log"), "dt_bias": WeightConverter("dt_bias", "dt_bias"), - "norm": LinearWeightConverter("norm", "norm", bias_fn=no_bias), + "norm": LinearWeightConverter("norm", "norm", bias_fn=False), } @@ -249,7 +247,6 @@ def _create_config_converters(cls) -> dict: @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: # KimiDeltaAttention has no linear biases. 
- no_bias = lambda c: False proj_names = ( "q_proj", "k_proj", @@ -265,10 +262,10 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "o_proj", ) return { - **{name: LinearWeightConverter(name, name, bias_fn=no_bias) for name in proj_names}, + **{name: LinearWeightConverter(name, name, bias_fn=False) for name in proj_names}, "A_log": WeightConverter("A_log", "A_log"), "dt_bias": WeightConverter("dt_bias", "dt_bias"), - "norm": LinearWeightConverter("norm", "norm", bias_fn=no_bias), + "norm": LinearWeightConverter("norm", "norm", bias_fn=False), } diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index f546eac8e..de3d3795c 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -48,8 +48,8 @@ LlamaEmbeddingsConverter, LlamaNormalizationConverter, assert_no_peft, + effective_bias, ) -from fast_llm.models.gpt.conversion.llama import effective_bias as _effective_bias from fast_llm.models.gpt.model import GPTModel from fast_llm.utils import Assert, safe_merge_dicts @@ -185,23 +185,23 @@ def _create_config_converters(cls) -> dict: @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: # Each linear layer carries its own ``bias.enabled`` override; the default falls back to the - # mixer-wide ``add_linear_biases`` via :func:`_effective_bias`. ``key_value`` is biased only when + # mixer-wide ``add_linear_biases`` via :func:`effective_bias`. ``key_value`` is biased only when # both K and V agree (Fast-LLM packs them as a single tensor). 
return { "query": LinearWeightConverter( - "query", "q_proj", bias_fn=lambda c: _effective_bias(c.query_layer, c.add_linear_biases) + "query", "q_proj", bias_fn=lambda c: effective_bias(c.query_layer, c.add_linear_biases) ), "key_value": LinearWeightConverter( "key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter, bias_fn=lambda c: ( - _effective_bias(c.key_layer, c.add_linear_biases) - and _effective_bias(c.value_layer, c.add_linear_biases) + effective_bias(c.key_layer, c.add_linear_biases) + and effective_bias(c.value_layer, c.add_linear_biases) ), ), "dense": LinearWeightConverter( - "dense", "o_proj", bias_fn=lambda c: _effective_bias(c.dense_layer, c.add_linear_biases) + "dense", "o_proj", bias_fn=lambda c: effective_bias(c.dense_layer, c.add_linear_biases) ), } @@ -316,14 +316,13 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - no_bias = lambda c: False return { - "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=no_bias), - "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=no_bias), + "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=False), + "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=False), "convolution": LinearWeightConverter( "convolution", "convolution", bias_fn=lambda c: c.convolution_layer.bias.enabled ), - "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=no_bias), + "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=False), "dt_bias": WeightConverter("dt_bias", "dt_bias"), "A_log": WeightConverter("A_log", "A_log"), "norm": NestedWeightConverter("norm", "norm", LlamaNormalizationConverter, config_attr="normalization"), @@ -376,7 +375,6 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - no_bias = lambda c: False 
proj_names = ( "q_proj", "k_proj", @@ -392,7 +390,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "o_proj", ) return { - **{name: LinearWeightConverter(name, name, bias_fn=no_bias) for name in proj_names}, + **{name: LinearWeightConverter(name, name, bias_fn=False) for name in proj_names}, "A_log": WeightConverter("A_log", "A_log"), "dt_bias": WeightConverter("dt_bias", "dt_bias"), "norm": NestedWeightConverter("norm", "norm", LlamaNormalizationConverter, config_attr="normalization"), @@ -533,13 +531,13 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "layer_1", lambda c: ("gate_proj", "up_proj") if c.gated else ("up_proj",), transform=SplitWeightConverter, - bias_fn=lambda c: _effective_bias(c.layer_1, c.add_linear_biases), + bias_fn=lambda c: effective_bias(c.layer_1, c.add_linear_biases), ), "layer_2": LinearWeightConverter( "layer_2", "down_proj", transform=TransposeSplitWeightConverter, - bias_fn=lambda c: _effective_bias(c.layer_2, c.add_linear_biases), + bias_fn=lambda c: effective_bias(c.layer_2, c.add_linear_biases), ), } @@ -598,6 +596,10 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: class Apriel2FixedDecoderConverter(ConfigSectionConverter): + """Config-only section: block fan-out lives on the base-model converter via + :class:`BlockSequenceWeightConverter`. This class exists for the config-side + :class:`DispatchConfigConverter` between Fixed and Pattern decoder shapes.""" + fast_llm_config_class = FixedBlockSequenceConfig hf_type_name = "fixed" block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter @@ -609,13 +611,10 @@ def _create_config_converters(cls) -> dict: "block": NestedConfigConverter(("block",), cls.block_converter_class, hf_path=("block",)), } - # The block fan-out lives on the base-model converter, which uses :class:`BlockSequenceWeightConverter` - # directly (Fixed/Pattern dispatch and block iteration share one primitive). 
The Fixed/Pattern decoder - # section converters exist for the config side (dispatch via :class:`DispatchConfigConverter`) and - # contribute no weights of their own. - class Apriel2PatternDecoderConverter(ConfigSectionConverter): + """Config-only section; see :class:`Apriel2FixedDecoderConverter`.""" + fast_llm_config_class = PatternBlockSequenceConfig hf_type_name = "pattern" block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter @@ -632,8 +631,6 @@ def _create_config_converters(cls) -> dict: ), } - # See note on :class:`Apriel2FixedDecoderConverter` — block fan-out lives on the base-model converter. - APRIEL2_DECODER_REGISTRY: dict[type[Config], type[ConfigSectionConverter]] = { FixedBlockSequenceConfig: Apriel2FixedDecoderConverter, @@ -641,16 +638,6 @@ def _create_config_converters(cls) -> dict: } -def get_apriel2_decoder_converter( - decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig, -) -> type[ConfigSectionConverter]: - """Look up the Apriel2 per-shape decoder converter for a given decoder config instance.""" - converter_class = APRIEL2_DECODER_REGISTRY.get(type(decoder_config)) - if converter_class is None: - raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") - return converter_class - - class Apriel2HeadConverter(ConfigSectionConverter): fast_llm_config_class = LanguageModelHeadConfig @@ -775,7 +762,3 @@ def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: cls._check_hf_coverage(config) return {"base_model": cls.base_model_converter_class.import_config(config)} - - @classmethod - def _get_weight_converters(cls, config: GPTModelConfig, export_config: dict) -> list[WeightConverter]: - return cls.base_model_converter_class.get_converters(config.base_model, export_config) diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py 
index 47fe8ed66..bd3cc047e 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -218,14 +218,14 @@ def export_config(cls, sliding_config: AttentionConfig, full_config: AttentionCo @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { - "query": LinearWeightConverter("query", "q_proj", bias_fn=lambda c: False), + "query": LinearWeightConverter("query", "q_proj", bias_fn=False), "key_value": LinearWeightConverter( "key_value", lambda c: "k_proj" if c.shared_key_value else ("k_proj", "v_proj"), transform=_Gemma4SharedKeyValueWeightConverter, - bias_fn=lambda c: False, + bias_fn=False, ), - "dense": LinearWeightConverter("dense", "o_proj", bias_fn=lambda c: False), + "dense": LinearWeightConverter("dense", "o_proj", bias_fn=False), # ``value_norm`` is :class:`FixedRMSNormConfig` (no learnable weight) — not declared. "query_norm": NestedWeightConverter("query_norm", "q_norm", LlamaNormalizationConverter, optional=True), "key_norm": NestedWeightConverter("key_norm", "k_norm", LlamaNormalizationConverter, optional=True), @@ -263,10 +263,10 @@ def export_config(cls, config: MLPConfig) -> dict: def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { "layer_1": LinearWeightConverter( - "layer_1", ("gate_proj", "up_proj"), transform=SplitWeightConverter, bias_fn=lambda c: False + "layer_1", ("gate_proj", "up_proj"), transform=SplitWeightConverter, bias_fn=False ), "layer_2": LinearWeightConverter( - "layer_2", "down_proj", transform=TransposeSplitWeightConverter, bias_fn=lambda c: False + "layer_2", "down_proj", transform=TransposeSplitWeightConverter, bias_fn=False ), } @@ -326,7 +326,7 @@ def export_config(cls, config: MoEMLPConfig, hidden_size: int) -> dict: def _create_weight_converters(cls) -> dict[str, WeightConverter]: # ``router.norm`` is :class:`FixedRMSNormConfig` (no learnable weight) — not declared. 
return { - "router": LinearWeightConverter("router", "router.proj", bias_fn=lambda c: False), + "router": LinearWeightConverter("router", "router.proj", bias_fn=False), "router_scale": WeightConverter("router_scale", "router.scale"), "router_per_expert_scale": WeightConverter("router_per_expert_scale", "router.per_expert_scale"), "layer_1": Gemma4MoELayer1Converter("layer_1.weight", "experts.gate_up_proj"), diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index b5f4bf0dd..e321d0d31 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -81,7 +81,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "key_value": LinearWeightConverter( "key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter, bias_fn=lambda c: True ), - "dense": LinearWeightConverter("dense", "o_proj", bias_fn=lambda c: False), + "dense": LinearWeightConverter("dense", "o_proj", bias_fn=False), } diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 6cbefa0c2..1c8511b11 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -308,7 +308,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "patch_embeddings", "patch_embeddings", transform=PatchEmbeddingWeightConverter, - bias_fn=lambda c: False, + bias_fn=False, ), "normalization": NestedWeightConverter( "normalization", "normalization", LlamaNormalizationConverter, config_attr="normalization" @@ -541,7 +541,3 @@ def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]: def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: cls._check_hf_coverage(config) return {"base_model": cls.base_model_converter_class.import_config(config)} - - @classmethod - def _get_weight_converters(cls, config: MultiModalModelConfig, export_config: dict) -> 
list[WeightConverter]: - return cls.base_model_converter_class.get_converters(config.base_model, export_config) diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index ffd84173d..6d6db9a51 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -162,7 +162,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "patch_embeddings", "patch_conv", transform=PatchEmbeddingWeightConverter, - bias_fn=lambda c: False, + bias_fn=False, ), # ``PixtralEmbeddingsConverter``'s section config IS the ``PatchEmbeddingsConfig`` (carries the # ``normalization`` sub-config directly), so the nested ``LlamaNormalizationConverter`` reads @@ -291,11 +291,9 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: class LlavaHeadConverter(LlamaHeadConverter): - # Llava always writes ``lm_head.weight`` on export (never dropped, even when ``tied_embedding_weight=True``); - # the parent's :class:`OutputProjectionWeightConverter` would also drop on export, so we replace it with a - # plain rename. When the HF state-dict lacks ``lm_head.weight`` (tied case), the handler's per-converter - # ``all(name in state_dict)`` check makes the rename a silent no-op on import — equivalent to the previous - # ``drop_on_import=tied`` behaviour, without the extra parameter plumbing. + # Llava always emits a separate ``language_model.lm_head.weight`` declaration even when + # ``tied_embedding_weight=True``, so the head uses a plain rename instead of + # :class:`OutputProjectionWeightConverter` (which drops on export under the tied flag). 
@classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: From 15e67c288cc7f18004f357dab65b9f7f047d8f92 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 26 May 2026 12:40:26 -0400 Subject: [PATCH 10/12] Make HuggingFaceBaseModelConverter inherit ConfigSectionConverter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes the `# type: ignore[attr-defined]` on `cls.emit_weight_converters` in `huggingface.py` and the explicit `(ConfigSectionConverter, ...)` multi-inheritance on every concrete BaseModelConverter — the mixin relationship is now declared at the right level. Also fold `_Gemma4SharedKeyValueWeightConverter` onto `KeyValueWeightConverter` (delegates to `super()` on the non-shared branch instead of duplicating the chunk/cat logic) and extend the `weight_only` docstring to acknowledge that type-dispatched sections (Apriel/Apriel2) achieve the same effect through their own structure. Co-Authored-By: Claude Opus 4.7 --- fast_llm/engine/checkpoint/external.py | 5 +++++ fast_llm/engine/checkpoint/huggingface.py | 19 ++++++++----------- fast_llm/models/gpt/conversion/apriel2.py | 2 +- fast_llm/models/gpt/conversion/gemma4.py | 15 ++++++--------- fast_llm/models/gpt/conversion/llama.py | 2 +- .../models/multimodal/conversion/apriel2.py | 2 +- .../models/multimodal/conversion/llava.py | 2 +- 7 files changed, 23 insertions(+), 24 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index b26e99d2d..744663a6e 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -564,6 +564,11 @@ class ConfigSectionConverter(abc.ABC): ancestor (typically via :class:`CustomConfigConverter` with ``fast_llm_recurses=True``). 
:meth:`_create_config_converters` defaults to no declarations and :meth:`check_architecture_coverage` short-circuits, so the section does not need to claim its own architecture leaves. + + Sections whose ancestor isn't a recursive :class:`CustomConfigConverter` (e.g. Apriel/Apriel2's + type-dispatched blocks) handle the same situation through their own structure — the dispatching + primitive (Nested/Dispatch/TypedDictContainer) claims the subtree recursively — and don't need this + flag. """ @classmethod diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 5b4400a4b..23b602a53 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -10,6 +10,7 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig from fast_llm.engine.checkpoint.external import ( + ConfigSectionConverter, ExternalStateDictCheckpointHandler, WeightConverter, logger, @@ -22,16 +23,12 @@ import transformers -class HuggingFaceBaseModelConverter: - @classmethod - @abc.abstractmethod - def import_config(cls, config: dict) -> dict: - pass - - @classmethod - @abc.abstractmethod - def export_config(cls, config: BaseModelConfig) -> dict: - pass +class HuggingFaceBaseModelConverter(ConfigSectionConverter): + """Section converter for a full HF model root. Inherits the declarative config-side machinery from + :class:`ConfigSectionConverter` (``import_config`` / ``export_config`` driven by + ``_create_config_converters``) and adds the weight-side ``get_converters`` aggregation that the + enclosing :class:`HuggingfaceStateDictCheckpointHandler` invokes. 
+ """ @classmethod def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: @@ -41,7 +38,7 @@ def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: :class:`LlamaBaseModelConverter` splices the head's weights separately so MTP-Llama's per-prediction-head fan-out has access to the full base-model config. """ - return cls.emit_weight_converters(config, "", "") # type: ignore[attr-defined] + return cls.emit_weight_converters(config, "", "") class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index de3d3795c..956225b40 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -676,7 +676,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: } -class Apriel2BaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class Apriel2BaseModelConverter(HuggingFaceBaseModelConverter): fast_llm_config_class = GPTBaseModelConfig embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index bd3cc047e..2153a4220 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -3,8 +3,6 @@ import functools import typing -import torch - from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( @@ -13,6 +11,7 @@ CustomConfigConverter, DispatchWeightConverter, IgnoredConfigConverter, + KeyValueWeightConverter, LinearWeightConverter, NestedWeightConverter, RenameConfigConverter, @@ -101,9 +100,9 @@ def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): ) -class _Gemma4SharedKeyValueWeightConverter(WeightConverter): +class 
_Gemma4SharedKeyValueWeightConverter(KeyValueWeightConverter): """``shared_key_value=True`` Gemma4 attention: Fast-LLM's ``key_value`` is a single K-shaped - tensor (V is reused at runtime) and maps to a single HF ``k_proj`` — plain rename. Falls back to + tensor (V is reused at runtime) and maps to a single HF ``k_proj`` — plain rename. Delegates to :class:`KeyValueWeightConverter` (chunk/cat across K and V) when not shared. """ @@ -112,14 +111,12 @@ class _Gemma4SharedKeyValueWeightConverter(WeightConverter): def export_weight(self, weight): if self._config.shared_key_value: return weight - (key_value,) = weight - return key_value[:].chunk(2) + return super().export_weight(weight) def import_weight(self, weight): if self._config.shared_key_value: return weight - key, value = weight - return (torch.cat([key[:], value[:]]),) + return super().import_weight(weight) class Gemma4AttentionConverter(ConfigSectionConverter): @@ -582,7 +579,7 @@ def _gemma4_bidirectional_import(hf_dict: dict) -> dict: return {} -class Gemma4BaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class Gemma4BaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for ``GPTBaseModelConfig`` ↔ Gemma 4 HF dict. Gemma 4 has several wrinkles that prevent the standard per-section decomposition used by Llama: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 4d1deb6b3..4338b6c26 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -444,7 +444,7 @@ def get_converters( return cls.emit_weight_converters(config.head, "head", "", root_config=config) -class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class LlamaBaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for ``GPTBaseModelConfig`` ↔ Llama HF dict. 
Subclasses (Mistral, Qwen2, Mixtral, MTP-Llama, …) override ``block_converter_class`` to plug their diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 1c8511b11..f5bc47dc3 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -417,7 +417,7 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: } -class Apriel2MultimodalBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class Apriel2MultimodalBaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for Apriel2 multimodal. Composes the Apriel2 text base (flat-merged into the HF top-level dict) with an optional vision encoder (under HF key ``vision_encoder``) and an optional ``image_token_index`` field. diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 6d6db9a51..1a67cf8ec 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -312,7 +312,7 @@ class LlavaLanguageModelConverter(MistralBaseModelConverter): head_converter_class: typing.ClassVar[type[LlavaHeadConverter]] = LlavaHeadConverter -class LlavaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class LlavaBaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for Llava. Composes: * ``text_config`` HF subdict ← :class:`LlavaLanguageModelConverter` (Mistral text base). 
From 3a574e4e189642953d4698a36852adaa30060208 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 26 May 2026 13:27:31 -0400 Subject: [PATCH 11/12] Strip downstream-consumer references and code-restating comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per CLAUDE.md (don't reference downstream consumers in explanatory text, don't restate what the code says): primitive docstrings in ``external.py`` (Nested/DispatchBlockSequence/SelfBlockSequence/Dispatch/ TypedDict/Linear) and ``huggingface.py``'s ``get_converters`` default no longer name specific formats; ``effective_bias``, ``LlamaHeadConverter.get_converters``, and the head-aggregator comment on ``LlamaBaseModelConverter`` rephrase generically; the local restating comments in ``mtp_llama.py``, ``apriel.py``, and ``multimodal/apriel2.py`` are dropped. Also fixes three small docstring/typing issues: * ``KeyValueWeightConverter`` no longer claims "identity for bias" — biases are chunked/concatenated the same way as the weight. * ``WeightConverter._emit``'s example references ``NestedWeightConverter`` / ``LinearWeightConverter`` (real overrides that capture construction state) instead of ``KeyValueWeightConverter`` (which doesn't override ``_emit``). * ``_Gemma4SharedKeyValueWeightConverter`` and ``_Gemma4BlockNorm2WeightConverter._emit`` get the missing type annotations restored, mirroring the base class. 
Co-Authored-By: Claude Opus 4.7 --- fast_llm/engine/checkpoint/external.py | 59 +++++++++---------- fast_llm/engine/checkpoint/huggingface.py | 5 +- fast_llm/models/gpt/conversion/apriel.py | 2 - fast_llm/models/gpt/conversion/gemma4.py | 20 ++++++- fast_llm/models/gpt/conversion/llama.py | 20 +++---- fast_llm/models/gpt/conversion/mtp_llama.py | 1 - .../models/multimodal/conversion/apriel2.py | 2 - 7 files changed, 55 insertions(+), 54 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 744663a6e..4a6ff36ea 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -774,7 +774,7 @@ def check_architecture_coverage(cls, config: Config) -> None: def prepend_prefix(prefix: str, names: tuple[str, ...]) -> tuple[str, ...]: - """Prepend ``prefix`` to each name. Empty ``prefix`` is a no-op; empty ``names`` (drop side) stays empty.""" + """Prepend ``prefix`` to each name. Empty ``prefix`` is a no-op; empty ``names`` stays empty.""" if not prefix: return names return tuple(f"{prefix}.{name}" for name in names) @@ -831,8 +831,9 @@ def _emit( ) -> list["WeightConverter"]: """Return a fully-qualified emitted copy of this leaf. - Subclasses that capture extra construction state (e.g. :class:`KeyValueWeightConverter` stashing - an :class:`AttentionConfig`) override this hook to pass that state into the emitted copy. + Subclasses that capture extra construction state (e.g. :class:`NestedWeightConverter` / + :class:`LinearWeightConverter` holding a sub-converter class or a callable prefix) override + this hook to pass that state through to the emitted output. """ return [ type(self)( @@ -844,6 +845,8 @@ def _emit( class SplitWeightConverter(WeightConverter): + """Split a merged tensor evenly across the listed export names; concatenate on import.""" + def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] 
) -> tuple[torch.Tensor | SafeTensorSlice, ...]: @@ -859,9 +862,8 @@ def import_weight( class TransposeSplitWeightConverter(WeightConverter): """Split a merged weight across the last dim with an additional transpose. - Equivalent to :class:`SplitWeightConverter` for non-gated MLPs (trivial split) and for 1-D biases - (trivial transpose); the real behaviour kicks in for the down-projection of a gated MLP where HF - stores the weight in transposed orientation. + The transpose only matters when the source tensor's last two dims differ; for 1-D biases and the + trivial single-name case it degenerates to a passthrough split. """ def export_weight( @@ -881,7 +883,7 @@ class KeyValueWeightConverter(WeightConverter): """Pack/unpack a fused key-value tensor across the two HF names. Fast-LLM packs key/value as a single concatenated tensor; HF stores them as two siblings - (``k_proj`` / ``v_proj``). Identity for bias because biases are concatenated the same way. + (``k_proj`` / ``v_proj``). The same split applies to the 1-D bias. """ def export_weight( @@ -954,17 +956,16 @@ class NestedWeightConverter(WeightConverter): The sub-section's config is read by chained ``getattr`` from ``config`` via ``config_attr``: - * ``None`` (default) — single attribute named after ``fast_llm_prefix`` (e.g. ``getattr(config, "mixer")``). - * single string — single attribute (covers a block's ``normalization`` config feeding two state-dict - prefixes ``norm_1`` / ``norm_2``). - * tuple of strings — chained descent (e.g. ``("dense", "pre_norm")`` ↦ ``config.dense.pre_norm`` for - Gemma4's hybrid-MoE inner norms). + * ``None`` (default) — single attribute named after ``fast_llm_prefix``. + * single string — single attribute (e.g. one ``normalization`` config feeding two state-dict + prefixes). + * tuple of strings — chained descent (e.g. ``("dense", "pre_norm")`` ↦ ``config.dense.pre_norm``). The walker descends into ``sub_converter_class._create_weight_converters()`` with extended prefixes. 
Mirrors :class:`NestedConfigConverter` on the config side. ``optional=True`` skips the recursion when the resolved sub-config is ``None`` — for optional - architecture sections like Llava's ``vision_encoder``. + architecture sub-sections. """ def __init__( @@ -1073,7 +1074,8 @@ class DispatchBlockSequenceWeightConverter(WeightConverter): """Fan a per-position-dispatched block converter across every position in a block sequence. Each position's block config is matched against ``dispatch_registry`` keys by its mixer type - (``type(block.mixer)``) — Apriel's hybrid-block dispatch. + (``type(block.mixer)``), allowing hybrid sequences where different layers use different block + converters. """ def __init__( @@ -1116,8 +1118,7 @@ class SelfBlockSequenceWeightConverter(WeightConverter): """Fan a single block converter across the section config when *the section IS the block sequence*. Used when the declaring section's ``fast_llm_config_class`` is itself a ``FixedBlockSequenceConfig`` - or ``PatternBlockSequenceConfig`` (e.g. ``LlamaDecoderConverter`` plugged into the Pixtral vision - encoder; Apriel2 and Gemma4's decoders). Weights land directly under the section's outer prefixes. + or ``PatternBlockSequenceConfig``. Weights land directly under the section's outer prefixes. """ def __init__(self, block_converter_class: type["ConfigSectionConverter"]): @@ -1149,11 +1150,10 @@ class DispatchWeightConverter(WeightConverter): Reads ``getattr(config, config_attr)`` (defaults to ``fast_llm_prefix``), looks up its type in ``registry``, and recurses into that ConfigSectionConverter with the standard extended prefixes. Mirrors :class:`DispatchConfigConverter` on the config side. Used when a single attribute holds one - of several alternative configs (e.g. Apriel2's block ``mixer`` may be attention/mamba/gdn/kda/stochastic; - its ``normalization`` may be RMS/Layer/None). + of several alternative configs. 
- ``hf_prefix_overrides`` lets individual branches replace the shared ``hf_prefix`` (e.g. Gemma4's - hybrid MoE flat-merges into the block root while dense MLP lands under ``mlp.<...>``). + ``hf_prefix_overrides`` lets individual branches replace the shared ``hf_prefix`` (e.g. when one + branch flat-merges into the parent root while siblings nest under a sub-prefix). """ def __init__( @@ -1197,8 +1197,7 @@ class TypedDictWeightConverter(WeightConverter): For each entry, looks up its type in ``registry`` and recurses into that converter with names ``{fast_llm_prefix}.{key}`` / ``{hf_prefix}.{key}``. Mirrors - :class:`TypedDictContainerConfigConverter` on the config side. Used for collections of named sub- - configs (e.g. Apriel2 StochasticMixer's ``mixers`` dict). + :class:`TypedDictContainerConfigConverter` on the config side. """ def __init__( @@ -1242,13 +1241,13 @@ class LinearWeightConverter(WeightConverter): """Bundle a linear layer's ``.weight`` and (conditionally) ``.bias`` declarations into one entry. Bias presence is resolved at emission time from the live section config: ``bias_fn`` is either a bool - literal (always / never) or a callable returning a bool. The default reads ``config.add_linear_biases`` — - the shared flag every Llama-style attention/MLP section carries. Sections with per-layer overrides (e.g. - Apriel Mamba's ``dt_layer`` / ``convolution_layer``) pass a lambda that resolves the override. + literal (always / never) or a callable returning a bool. The default reads + ``config.add_linear_biases``; sections with per-layer overrides pass a lambda that resolves the + override. - ``transform`` selects the leaf class for both weight and bias: :class:`WeightConverter` for plain rename - (the default), :class:`SplitWeightConverter` for fused → split, :class:`KeyValueWeightConverter` for - fused KV → separate K/V, :class:`TransposeSplitWeightConverter` for MLP down-projection. 
+ ``transform`` selects the leaf class for both weight and bias: :class:`WeightConverter` for plain + rename (the default), :class:`SplitWeightConverter` for fused → split, :class:`KeyValueWeightConverter` + for fused KV → separate K/V, :class:`TransposeSplitWeightConverter` for a transposed split. """ def __init__( @@ -1261,8 +1260,8 @@ def __init__( ): super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix - # ``hf_prefix`` may be a callable (e.g. Mixtral's ``experts.{i}.w1``-style fan-out where the - # expert count comes from the live config). + # ``hf_prefix`` may be a callable when the HF names depend on a runtime-known shape (e.g. a + # count read from the live section config); resolved at emission time. self._hf_prefix = hf_prefix self._transform = transform self._bias_fn = bias_fn diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 23b602a53..c055a7f2c 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -34,9 +34,8 @@ class HuggingFaceBaseModelConverter(ConfigSectionConverter): def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: """Default: walk the section's weight declarations from the root. - Subclasses with constructs that don't fit the standard declaration walk override — e.g. - :class:`LlamaBaseModelConverter` splices the head's weights separately so MTP-Llama's - per-prediction-head fan-out has access to the full base-model config. + Subclasses override when a section needs cross-section state from the full base-model config + (typically when an extension point on the head must read from a sibling section). 
""" return cls.emit_weight_converters(config, "", "") diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index ea8e2cc9e..d024f7f89 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -425,8 +425,6 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # Override the parent's flat ``decoder`` entry with a per-position dispatch version that picks - # the right block converter from the dispatcher's registry based on the mixer's runtime type. return { **super()._create_weight_converters(), "decoder": DispatchBlockSequenceWeightConverter( diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 2153a4220..75d65911c 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -3,6 +3,8 @@ import functools import typing +import torch + from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( @@ -41,6 +43,7 @@ LlamaNormalizationConverter, ) from fast_llm.models.gpt.model import GPTModel +from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, safe_merge_dicts _SLIDING_ATTENTION = "sliding_attention" @@ -89,7 +92,14 @@ class _Gemma4BlockNorm2WeightConverter(WeightConverter): def __init__(self) -> None: super().__init__((), ()) - def _emit(self, config, fast_llm_prefix, hf_prefix, *, root_config): + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: if isinstance(config.mlp, HybridMoEMLPConfig): return [] return LlamaNormalizationConverter.emit_weight_converters( @@ -108,12 +118,16 @@ class _Gemma4SharedKeyValueWeightConverter(KeyValueWeightConverter): _config: AttentionConfig - def export_weight(self, 
weight): + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: if self._config.shared_key_value: return weight return super().export_weight(weight) - def import_weight(self, weight): + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: if self._config.shared_key_value: return weight return super().import_weight(weight) diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 4338b6c26..d57665e6c 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -57,11 +57,7 @@ def assert_no_peft(config: GPTBaseModelConfig) -> None: def effective_bias(layer_config: AffineLinearConfig, default: bool) -> bool: - """Resolve a layer's effective bias flag: explicit ``bias.enabled`` if set, else the parent default. - - Used by Apriel and Apriel2 config-side ``CustomConfigConverter`` export_fns (which need to translate - a per-layer override into an HF-side bias flag) and by their ``LinearWeightConverter.bias_fn`` lambdas. - """ + """Resolve a layer's effective bias flag: explicit ``bias.enabled`` if set, else the parent default.""" return default if layer_config.bias.enabled is None else layer_config.bias.enabled @@ -363,9 +359,7 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # The section config IS a ``FixedBlockSequenceConfig`` (no parent attribute holding it). Used by - # Pixtral's vision encoder and Apriel2's vision encoder; text formats inline the dispatch at the - # base-model converter instead. + # The section config IS a ``FixedBlockSequenceConfig`` (no parent attribute holding it). 
return { "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), } @@ -437,9 +431,9 @@ def get_converters( config: GPTBaseModelConfig, ) -> list[WeightConverter]: """Aggregator entry-point: the base-model converter passes the full :class:`GPTBaseModelConfig` - so subclasses (e.g. MTP-Llama) can read ``config.decoder.last_block_config`` / - ``config.head.prediction_heads`` when extending the head's weights. Tied-embedding handling - lives on :class:`OutputProjectionWeightConverter` and reads ``root_config.tied_embedding_weight``. + so subclasses extending the head can read sibling sections (e.g. the decoder) when needed. + Tied-embedding handling lives on :class:`OutputProjectionWeightConverter` and reads + ``root_config.tied_embedding_weight``. """ return cls.emit_weight_converters(config.head, "head", "", root_config=config) @@ -503,8 +497,8 @@ def _validate_export(cls, config: GPTBaseModelConfig) -> None: @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: # ``head`` is added at the aggregator level (in :meth:`get_converters`) because the head - # converter takes the full base-model config so subclasses can extend it (e.g. MTP-Llama - # fans out per-prediction-head blocks and norms). + # converter takes the full base-model config so subclasses extending the head can read + # sibling sections. 
return { "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), "decoder": BlockSequenceWeightConverter("decoder", "model.layers", cls.block_converter_class), diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 4513b5bd4..1db31c6cd 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -52,7 +52,6 @@ def get_converters( config: GPTBaseModelConfig, ) -> list[WeightConverter]: converters = list(cls.emit_weight_converters(config.head, "head", "", root_config=config)) - # Append the MTP fan-out: one block + one norm per extra prediction head. for prediction_distance in range(2, config.head.prediction_heads + 1): converters += cls.block_converter_class.emit_weight_converters( config.decoder.last_block_config, diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index f5bc47dc3..62d9561d1 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -254,8 +254,6 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # The section config IS the FixedBlockSequenceConfig — SelfBlockSequenceWeightConverter reads - # the section config directly instead of via ``getattr``. 
return { "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), } From 19259318e801397b9ae802ccd502c65c6f33b79a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 26 May 2026 18:41:24 -0400 Subject: [PATCH 12/12] =?UTF-8?q?Fine-pass=20cleanup:=20dead=20overrides,?= =?UTF-8?q?=20lambda=20=E2=86=92=20bool,=20stale=20comments?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Drop ``Apriel2MultimodalHeadConverter`` — its ``_create_weight_converters`` was byte-identical to the inherited ``Apriel2HeadConverter`` version, so the empty subclass adds nothing; ``head_converter_class`` now points directly at the parent. * Qwen2 attention ``query`` / ``key_value`` use ``bias_fn=True`` instead of ``lambda c: True`` — consistent with ``bias_fn=False`` on the dense branch right below, and clearer that no config inspection is happening. * Drop ``LinearWeightConverter.__init__``'s ``hf_prefix``-may-be-a-callable comment that restated the type annotation directly above it. * Update ``AprielBaseModelConverter.block_dispatcher_class`` field comment to name ``DispatchBlockSequenceWeightConverter`` (the actual declarative consumer) instead of the long-removed weight-side loop. Co-Authored-By: Claude Opus 4.7 --- fast_llm/engine/checkpoint/external.py | 2 -- fast_llm/models/gpt/conversion/apriel.py | 2 +- fast_llm/models/gpt/conversion/qwen2.py | 4 ++-- fast_llm/models/multimodal/conversion/apriel2.py | 15 +-------------- 4 files changed, 4 insertions(+), 19 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 4a6ff36ea..c913859fe 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1260,8 +1260,6 @@ def __init__( ): super().__init__((), ()) self._fast_llm_prefix = fast_llm_prefix - # ``hf_prefix`` may be a callable when the HF names depend on a runtime-known shape (e.g. 
a - # count read from the live section config); resolved at emission time. self._hf_prefix = hf_prefix self._transform = transform self._bias_fn = bias_fn diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index d024f7f89..103c42dae 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -406,7 +406,7 @@ class AprielBaseModelConverter(MistralBaseModelConverter): # Distinct from the parent's ``block_converter_class`` (a single ``ConfigSectionConverter``); this # one holds the per-mixer-type dispatch registries that :class:`ListDispatchConfigConverter` and - # the weight-side loop below consume. + # :class:`DispatchBlockSequenceWeightConverter` consume. block_dispatcher_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index e321d0d31..ec2a53c47 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -77,9 +77,9 @@ def _validate_export(cls, config: AttentionConfig) -> None: @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: return { - "query": LinearWeightConverter("query", "q_proj", bias_fn=lambda c: True), + "query": LinearWeightConverter("query", "q_proj", bias_fn=True), "key_value": LinearWeightConverter( - "key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter, bias_fn=lambda c: True + "key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter, bias_fn=True ), "dense": LinearWeightConverter("dense", "o_proj", bias_fn=False), } diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 62d9561d1..d7832db19 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -15,7 +15,6 @@ NestedConfigConverter, 
NestedWeightConverter, OptionalConfigConverter, - OutputProjectionWeightConverter, PatchEmbeddingWeightConverter, RenameConfigConverter, SelfBlockSequenceWeightConverter, @@ -403,18 +402,6 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: } -class Apriel2MultimodalHeadConverter(Apriel2HeadConverter): - @classmethod - @functools.cache - def _create_weight_converters(cls) -> dict[str, WeightConverter]: - return { - "final_norm": NestedWeightConverter( - "final_norm", "model.norm", cls.normalization_converter_class, config_attr="normalization" - ), - "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), - } - - class Apriel2MultimodalBaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for Apriel2 multimodal. Composes the Apriel2 text base (flat-merged into the HF top-level dict) with an optional vision encoder (under HF key ``vision_encoder``) and an optional @@ -431,7 +418,7 @@ class Apriel2MultimodalBaseModelConverter(HuggingFaceBaseModelConverter): vision_model_converter_class: typing.ClassVar[type[Apriel2VisionModelConverter]] = Apriel2VisionModelConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter - head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter + head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @classmethod def _create_config_converters(cls) -> dict: