From ec8966861c68bbbd3302661d140cba566700bf13 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 13:49:37 -0400 Subject: [PATCH 1/7] Make Gemma4 base-model weight aggregation declarative Express the embeddings/decoder weight mapping as NestedWeightConverter declarations in _create_weight_converters and reduce get_converters to the canonical emit+head form, mirroring LlamaBaseModelConverter. No change to emitted converters. Co-Authored-By: Claude Opus 4.8 (1M context) --- fast_llm/models/gpt/conversion/gemma4.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 75d65911c..707da8999 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -701,15 +701,21 @@ def _head_import(hf_dict: dict) -> dict: "vocab_size_per_layer_input": IgnoredConfigConverter(hf_paths=(("vocab_size_per_layer_input",),)), } + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``head`` is added at the aggregator level (in :meth:`get_converters`) because the head + # converter takes the full base-model config so subclasses extending the head can read + # sibling sections. + return { + "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), + "decoder": NestedWeightConverter("decoder", "model.layers", cls.decoder_converter_class), + } + @classmethod def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.emit_weight_converters( - config.embeddings, "embeddings", "model", root_config=config - ), - *cls.decoder_converter_class.emit_weight_converters( - config.decoder, "decoder", "model.layers", root_config=config - ), + *cls.emit_weight_converters(config, "", ""), *cls.head_converter_class.get_converters(config), ] From eee018d3ffa55a5cb57382e0add8961c9b78d615 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 13:49:39 -0400 Subject: [PATCH 2/7] Make MTP-Llama prediction-head stack declarative via RepeatWeightConverter Add RepeatWeightConverter, a structural primitive that fans a sub-section converter over a config-driven count with index-templated prefixes (reusing one sub-config per iteration, unlike BlockSequenceWeightConverter which walks a per-position config list). Use it to declare MTP-Llama's per-prediction-distance blocks and norms on the base-model converter (where they belong: they live at the model root as multi_token_prediction.*, not under head), dropping the imperative get_converters override on the head converter. Emitted converters are unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) --- fast_llm/engine/checkpoint/external.py | 50 +++++++++++++++++++ fast_llm/models/gpt/conversion/mtp_llama.py | 54 +++++++++++---------- 2 files changed, 79 insertions(+), 25 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index c913859fe..285c8fe19 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1144,6 +1144,56 @@ def _emit( return out +class RepeatWeightConverter(WeightConverter): + """Repeat a sub-section converter a config-driven number of times, with index-templated prefixes. + + Unlike :class:`BlockSequenceWeightConverter` — which fans a converter over a materialized per-position + config list — every iteration recurses into the *same* sub-config; only the emitted prefixes change + with the index. Used where a runtime count, not a block list, drives the repeat (e.g. a stack of + prediction-distance heads all sharing one block / normalization config). + + ``count`` and ``sub_config`` are resolved from the live section config. ``fast_llm_prefix`` and + ``hf_prefix`` map a 0-based iteration index to the section-relative prefixes; the two need not share + the same index arithmetic — e.g. an HF-side stack whose element 0 is declared elsewhere is reached by + offsetting the index here. + """ + + def __init__( + self, + sub_converter_class: type["ConfigSectionConverter"], + *, + count: typing.Callable[[Config], int], + sub_config: typing.Callable[[Config], Config], + fast_llm_prefix: typing.Callable[[int], str], + hf_prefix: typing.Callable[[int], str], + ): + super().__init__((), ()) + self._sub_converter_class = sub_converter_class + self._count = count + self._sub_config = sub_config + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + sub_config = self._sub_config(config) + out: list[WeightConverter] = [] + for index in range(self._count(config)): + out += self._sub_converter_class.emit_weight_converters( + sub_config, + join_prefix(fast_llm_prefix, self._fast_llm_prefix(index)), + join_prefix(hf_prefix, self._hf_prefix(index)), + root_config=root_config, + ) + return out + + class DispatchWeightConverter(WeightConverter): """Dispatch a single sub-section converter based on the live config's runtime type. diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 1db31c6cd..c79c60964 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -8,9 +8,10 @@ NestedWeightConverter, OutputProjectionWeightConverter, RenameConfigConverter, + RepeatWeightConverter, WeightConverter, ) -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, @@ -36,9 +37,9 @@ def _create_config_converters(cls) -> dict: @classmethod @functools.cache def _create_weight_converters(cls) -> dict[str, WeightConverter]: - # MTP-Llama places the first prediction head's final norm under ``model.mtp_norms.0`` instead - # of the standard ``model.norm``; the additional MTP blocks/norms come from the imperative - # ``get_converters`` override below since their count depends on ``head.prediction_heads``. + # MTP-Llama places the first prediction head's final norm under ``model.mtp_norms.0`` instead of + # the standard ``model.norm``. The additional per-prediction-distance blocks and norms are + # declared on the base-model converter (they live at the model root, not under ``head``). return { "final_norm": NestedWeightConverter( "final_norm", "model.mtp_norms.0", cls.normalization_converter_class, config_attr="normalization" @@ -46,31 +47,34 @@ def _create_weight_converters(cls) -> dict[str, WeightConverter]: "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), } - @classmethod - def get_converters( - cls, - config: GPTBaseModelConfig, - ) -> list[WeightConverter]: - converters = list(cls.emit_weight_converters(config.head, "head", "", root_config=config)) - for prediction_distance in range(2, config.head.prediction_heads + 1): - converters += cls.block_converter_class.emit_weight_converters( - config.decoder.last_block_config, - f"multi_token_prediction.blocks.{prediction_distance - 2}", - f"model.mtp_heads.{prediction_distance - 2}", - root_config=config, - ) - converters += cls.normalization_converter_class.emit_weight_converters( - config.head.normalization, - f"multi_token_prediction.heads.{prediction_distance - 2}.final_norm", - f"model.mtp_norms.{prediction_distance - 1}", - root_config=config, - ) - return converters - class MTPLlamaBaseModelConverter(LlamaBaseModelConverter): head_converter_class: typing.ClassVar[type[MTPLlamaHeadConverter]] = MTPLlamaHeadConverter + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # The extra prediction-distance heads (distances 2..prediction_heads) repeat the main decoder's + # last block and the head normalization. They sit at the model root as ``multi_token_prediction.*``, + # interleaved on the HF side with the base head's ``model.mtp_norms.0`` (declared on the head). + return { + **super()._create_weight_converters(), + "multi_token_prediction_blocks": RepeatWeightConverter( + cls.block_converter_class, + count=lambda config: config.head.prediction_heads - 1, + sub_config=lambda config: config.decoder.last_block_config, + fast_llm_prefix=lambda index: f"multi_token_prediction.blocks.{index}", + hf_prefix=lambda index: f"model.mtp_heads.{index}", + ), + "multi_token_prediction_norms": RepeatWeightConverter( + cls.head_converter_class.normalization_converter_class, + count=lambda config: config.head.prediction_heads - 1, + sub_config=lambda config: config.head.normalization, + fast_llm_prefix=lambda index: f"multi_token_prediction.heads.{index}.final_norm", + hf_prefix=lambda index: f"model.mtp_norms.{index + 1}", + ), + } + class MTPLlamaHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaCheckpointFormat From 74c5db2edce5a3d2715a4d9c5f7461a942ebfb8e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 14:55:14 -0400 Subject: [PATCH 3/7] Gate optional Gemma4 transformers import in HF round-trip test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The top-level `from transformers import Gemma4ForCausalLM, Gemma4TextConfig` (added with Gemma 4 support in #492) makes the entire test_hf_roundtrip module fail to collect on transformers builds without Gemma 4 — taking the llama/mistral/qwen2/mixtral/mtp_llama round-trips down with it. Guard the import (matching the try/except pattern used elsewhere in tests) and skip any round-trip case whose model class is unavailable. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/models/test_hf_roundtrip.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/models/test_hf_roundtrip.py b/tests/models/test_hf_roundtrip.py index 48161dc99..681e05d9c 100644 --- a/tests/models/test_hf_roundtrip.py +++ b/tests/models/test_hf_roundtrip.py @@ -15,8 +15,6 @@ import torch from transformers import ( AutoConfig, - Gemma4ForCausalLM, - Gemma4TextConfig, LlamaConfig, LlamaForCausalLM, MistralConfig, @@ -54,6 +52,12 @@ from fast_llm_external_models.mtp_llama.configuration_mtp_llama import MTPLlamaConfig from fast_llm_external_models.mtp_llama.modeling_mtp_llama import MTPLlamaForCausalLM +try: + # Available only in transformers builds that ship Gemma 4. + from transformers import Gemma4ForCausalLM, Gemma4TextConfig +except ImportError: + Gemma4ForCausalLM = Gemma4TextConfig = None + @dataclasses.dataclass(frozen=True) class HFRoundtripCase: @@ -247,7 +251,19 @@ def make_model(self) -> PreTrainedModel: ] -@pytest.mark.parametrize("case", [pytest.param(c, id=c.name) for c in _HF_ROUNDTRIP_CASES]) +@pytest.mark.parametrize( + "case", + [ + pytest.param( + case, + id=case.name, + marks=pytest.mark.skipif( + case.model_class is None, reason="transformers build does not provide this model class" + ), + ) + for case in _HF_ROUNDTRIP_CASES + ], +) def test_hf_roundtrip(case: HFRoundtripCase, result_path: pathlib.Path): """HF model survives HF → Fast-LLM → HF with identical config and weights.""" base = result_path / "hf_roundtrip" / case.name From f663d7af104f728b7c956fa344e6afb01b264013 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 15:17:56 -0400 Subject: [PATCH 4/7] Derive HF metadata allowlist from PretrainedConfig for version robustness MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The HF-coverage check rejected configs from supported transformers builds: the hand-curated _HF_METADATA_ALLOWLIST omits generic PretrainedConfig fields (torchscript and the generation kwargs) that v4 dumps into to_dict() but v5 moved out to GenerationConfig. With setup.cfg pinning transformers>=4.57.3,<6.0.0, every format's HF import failed on v4 with "unknown key 'torchscript'". Union the static allowlist with the live PretrainedConfig().to_dict() key set — every key a bare PretrainedConfig carries is generic metadata, never architecture — so the check adapts across the whole supported range instead of pinning a version-specific set. Co-Authored-By: Claude Opus 4.8 (1M context) --- fast_llm/engine/checkpoint/huggingface.py | 25 +++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index c055a7f2c..d3b1b7134 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -1,4 +1,5 @@ import abc +import functools import json import pathlib import shutil @@ -121,9 +122,12 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: }, ) - # Top-level HF metadata keys that are always permitted, regardless of the converter tree. - # Covers transformers' generic ``PretrainedConfig`` fields (always present after ``to_dict()``) - # plus a handful of widely-shared metadata that Fast-LLM intentionally does not store. + # HF metadata keys that are always permitted, regardless of the converter tree. The generic + # ``PretrainedConfig`` fields are added dynamically (see :meth:`_hf_metadata_allowlist`) because the + # exact set drifts across the supported transformers range — e.g. the generation kwargs and + # ``torchscript`` that v4 dumps into ``to_dict()`` were moved out to ``GenerationConfig`` in v5. This + # static set covers the widely-shared metadata that Fast-LLM intentionally does not store but that a + # bare ``PretrainedConfig`` does not carry (model-specific defaults like ``max_position_embeddings``). _HF_METADATA_ALLOWLIST: typing.ClassVar[frozenset[str]] = frozenset( { # transformers PretrainedConfig @@ -156,6 +160,19 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: } ) + @classmethod + @functools.cache + def _hf_metadata_allowlist(cls) -> frozenset[str]: + """Static allowlist unioned with the live ``PretrainedConfig`` field set. + + Every key a bare ``PretrainedConfig`` carries is generic transformers metadata, never + architecture, so deriving them from the installed transformers keeps the coverage check correct + across the supported version range instead of hard-coding a version-specific set. + """ + import transformers + + return cls._HF_METADATA_ALLOWLIST | frozenset(transformers.PretrainedConfig().to_dict()) + @classmethod def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None: """Run the HF-side coverage check at the import boundary. @@ -163,7 +180,7 @@ def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None: Subclasses that override :meth:`_import_config` should call this explicitly to keep the check active. """ - cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST) + cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._hf_metadata_allowlist()) @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: From 07d4f46ba9cb7558913d6d0f8f044e9a9c7452bc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 16:57:29 -0400 Subject: [PATCH 5/7] Skip convert-group tests when transformers lacks the format's HF config class Conversion tests instantiate the format's transformers config class; on supported-but-older transformers builds the class may not exist (e.g. no Gemma 4 before v5), making test_conversion[gemma4] error with `module transformers has no attribute Gemma4TextConfig`. Mirror the existing requires_cuda env-skip: in testing_group_enabled, skip the convert group when the class can't be imported, so the suite stays green across the supported transformers range (>=4.57,<6.0). Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/utils/model_configs.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 77828e0bd..2d4e2f5d4 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -1111,6 +1111,20 @@ def model_testing_config(request) -> ModelTestingConfig: return MODEL_CONFIGS[request.param] +@functools.cache +def _hf_config_class_available(checkpoint_format: type[CheckpointFormat]) -> bool: + """Whether the installed transformers provides the format's HF config class. + + Conversion tests need it; older-but-supported transformers builds may lack a recent model + (e.g. no Gemma 4 before transformers v5), in which case the convert group skips rather than errors. + """ + try: + checkpoint_format.get_handler_class().get_transformers_configuration_class() + except (ImportError, AttributeError): + return False + return True + + def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slow: bool, show_skipped: bool) -> bool: if "model_testing_group" in item.keywords: assert hasattr(item, "callspec") and "model_testing_config" in item.callspec.params, item.nodeid @@ -1120,6 +1134,15 @@ def testing_group_enabled(item: pytest.Function, skip_slow: bool, skip_extra_slo if model_config.requires_cuda and not torch.cuda.is_available(): item.add_marker(pytest.mark.skip(reason=f"Cuda not available.")) for group in groups: + if ( + group == ModelTestingGroup.convert + and model_config.checkpoint_format is not None + and not _hf_config_class_available(model_config.checkpoint_format) + ): + item.add_marker( + pytest.mark.skip(reason=f"transformers build lacks the HF config class for {model_testing_config}") + ) + continue action = model_config.groups.get(group, ModelTestingGroupAction.unimportant) if action == ModelTestingGroupAction.main: pass From d253886fba70295a737630b410b918a1361d30d3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 18:07:18 -0400 Subject: [PATCH 6/7] Trim static HF allowlist entries now covered by the dynamic PretrainedConfig union MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The static _HF_METADATA_ALLOWLIST kept 13 generic fields (architectures, model_type, dtype, …) that a bare PretrainedConfig().to_dict() carries on every supported transformers version, so the dynamic union already covers them. Drop those; keep only what the union misses across the range: auto_map / torch_dtype / use_cache (absent from a bare config on both v4 and v5), the token ids (absent from a bare v5 config), and the model-specific init/pretraining metadata. Behavior-preserving. Co-Authored-By: Claude Opus 4.8 (1M context) --- fast_llm/engine/checkpoint/huggingface.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index d3b1b7134..7fbc42afd 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -130,24 +130,12 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: # bare ``PretrainedConfig`` does not carry (model-specific defaults like ``max_position_embeddings``). _HF_METADATA_ALLOWLIST: typing.ClassVar[frozenset[str]] = frozenset( { - # transformers PretrainedConfig - "_name_or_path", - "architectures", + # transformers metadata Fast-LLM does not store that a bare ``PretrainedConfig().to_dict()`` + # omits across the supported range (so the dynamic union would miss them). "auto_map", - "chunk_size_feed_forward", - "dtype", - "id2label", - "is_encoder_decoder", - "label2id", - "model_type", - "output_attentions", - "output_hidden_states", - "problem_type", - "return_dict", "torch_dtype", - "transformers_version", "use_cache", - # Token ids — generation/inference, not architecture. + # Token ids — generation/inference, not architecture (a bare v5 config omits these). "bos_token_id", "decoder_start_token_id", "eos_token_id", From 4b1f76945708b21963199ddaf0bfa3f0c2f5edb1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 29 May 2026 18:27:17 -0400 Subject: [PATCH 7/7] Drop consumer reference from RepeatWeightConverter docstring The example named a downstream consumer (prediction-distance head stack); the preceding sentence already conveys the generic intent (runtime count drives the repeat). Per the style guide, explanatory text should not reference specific consumers. Co-Authored-By: Claude Opus 4.8 (1M context) --- fast_llm/engine/checkpoint/external.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index 285c8fe19..0de823ca8 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -1149,8 +1149,7 @@ class RepeatWeightConverter(WeightConverter): Unlike :class:`BlockSequenceWeightConverter` — which fans a converter over a materialized per-position config list — every iteration recurses into the *same* sub-config; only the emitted prefixes change - with the index. Used where a runtime count, not a block list, drives the repeat (e.g. a stack of - prediction-distance heads all sharing one block / normalization config). + with the index. Used where a runtime count, not a block list, drives the repeat. ``count`` and ``sub_config`` are resolved from the live section config. ``fast_llm_prefix`` and ``hf_prefix`` map a 0-based iteration index to the section-relative prefixes; the two need not share