diff --git a/fast_llm/engine/checkpoint/external.py b/fast_llm/engine/checkpoint/external.py index d78e6b405..c913859fe 100644 --- a/fast_llm/engine/checkpoint/external.py +++ b/fast_llm/engine/checkpoint/external.py @@ -543,21 +543,33 @@ def _consumed_hf_paths(self) -> frozenset[tuple[str, ...]]: class ConfigSectionConverter(abc.ABC): """Base class for converting one Fast-LLM ``Config`` class ↔ one HF dict subtree. - Subclasses declare the conversion via ``_create_config_converters``. Format-specific cross-field - invariants go on the ``_validate_export`` hook. The weight side is imperative — concrete subclasses - provide a ``get_converters`` classmethod that emits :class:`WeightConverter` instances. + Subclasses declare the conversion via ``_create_config_converters`` (config side) and + ``_create_weight_converters`` (weight side). Format-specific cross-field invariants go on the + ``_validate_export`` hook. Subclasses that participate in :class:`DispatchConfigConverter` set ``hf_type_name`` to the discriminator value used by the HF format (e.g. ``"attention"``, ``"mamba"``). .. warning:: - :meth:`_create_config_converters` is ``@functools.cache``\\ d on the base class. Subclasses that override - it must return a *fresh* dict (idiomatically ``{**super()._create_config_converters(), ...}``); mutating - the parent's returned dict in place would corrupt the cache entry for every subsequent caller. + Both ``_create_config_converters`` and ``_create_weight_converters`` are ``@functools.cache``\\ d on + the base class. Subclasses that override them must return a *fresh* dict (idiomatically + ``{**super()._create_..._converters(), ...}``); mutating the parent's returned dict in place would + corrupt the cache entry for every subsequent caller. """ fast_llm_config_class: typing.ClassVar[type[Config]] hf_type_name: typing.ClassVar[str | None] = None + weight_only: typing.ClassVar[bool] = False + """When ``True``, the section contributes weight conversions only — its config side is owned by an + ancestor (typically via :class:`CustomConfigConverter` with ``fast_llm_recurses=True``). + :meth:`_create_config_converters` defaults to no declarations and :meth:`check_architecture_coverage` + short-circuits, so the section does not need to claim its own architecture leaves. + + Sections whose ancestor isn't a recursive :class:`CustomConfigConverter` (e.g. Apriel/Apriel2's + type-dispatched blocks) handle the same situation through their own structure — the dispatching + primitive (Nested/Dispatch/TypedDictContainer) claims the subtree recursively — and don't need this + flag. + """ @classmethod @functools.cache @@ -567,9 +579,24 @@ def _create_config_converters(cls) -> dict[str, ConfigConverter]: Cached per class — declarations are immutable and depend only on ``cls``. Subclasses must build and return a *fresh* dict (idiomatically ``{**super()._create_config_converters(), ...}``); mutating the returned dict in place would corrupt the parent's cache entry for every subsequent caller. + Weight-only sections (``weight_only=True``) inherit the empty default. """ + if cls.weight_only: + return {} raise NotImplementedError + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, "WeightConverter"]: + """Return weight-conversion declarations keyed by stable string name. + + Same shape and caching rules as :meth:`_create_config_converters`. Section-relative names; the + walker (:meth:`emit_weight_converters`) prepends the section's full ``(fast_llm_prefix, hf_prefix)`` + pair as it descends. Defaults to no declarations — sections that don't own any weights leave this + unoverridden. + """ + return {} + @classmethod def _validate_export(cls, config: Config) -> None: """Hook for format-specific export-time validation. Default no-op. @@ -606,6 +633,32 @@ def import_config(cls, hf_dict: dict) -> dict: out = {"type": fast_llm_type, **out} return out + @classmethod + def emit_weight_converters( + cls, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config | None = None, + ) -> list["WeightConverter"]: + """Walk this section's weight declarations against ``config`` into a flat list of fully-qualified + :class:`WeightConverter` instances. + + Each declaration in :meth:`_create_weight_converters` returns one or more ``WeightConverter`` leaves via + its :meth:`WeightConverter._emit` hook. Structural primitives (Nested, BlockSequence) recurse into + sub-section converters; leaves return a single prefixed copy of themselves. ``root_config`` carries the + top-level config through the recursion for primitives whose behaviour depends on it (e.g. + :class:`OutputProjectionWeightConverter` consults ``root_config.tied_embedding_weight``); the walker + seeds it from ``config`` on the outermost call. + """ + if root_config is None: + root_config = config + out: list["WeightConverter"] = [] + for declaration in cls._create_weight_converters().values(): + out.extend(declaration._emit(config, fast_llm_prefix, hf_prefix, root_config=root_config)) + return out + @classmethod @functools.cache def _consumed_hf_paths(cls) -> frozenset[tuple[str, ...]]: @@ -679,7 +732,12 @@ def check_architecture_coverage(cls, config: Config) -> None: Invoked from a test fixture (``tests/models/test_converters.py``) — not from the production export/import paths. Architecture coverage is a structural invariant of the converter declarations, so it only needs to hold once per (converter, config-class) pair, not on every save. + + Short-circuits for weight-only sections (``weight_only=True``): the config side is owned by an + ancestor declaration, so the architecture leaves are claimed there. """ + if cls.weight_only: + return Assert.is_(type(config), cls.fast_llm_config_class) declarations = cls._create_config_converters() explicit_paths: set[tuple[str, ...]] = set() @@ -715,7 +773,34 @@ def check_architecture_coverage(cls, config: Config) -> None: ) +def prepend_prefix(prefix: str, names: tuple[str, ...]) -> tuple[str, ...]: + """Prepend ``prefix`` to each name. Empty ``prefix`` is a no-op; empty ``names`` stays empty.""" + if not prefix: + return names + return tuple(f"{prefix}.{name}" for name in names) + + +def join_prefix(parent: str, own: str) -> str: + """Join two dot-separated prefixes, tolerating either being empty. + + Structural primitives (Nested/BlockSequence/Dispatch/TypedDict) call this when building the + descended ``(fast_llm_prefix, hf_prefix)`` for the recursive call — either side can legitimately + be empty (the section converter sits at the HF root, or the declaration's own prefix is empty). + """ + if parent and own: + return f"{parent}.{own}" + return parent or own + + class WeightConverter: + """Leaf weight-conversion declaration / emitted instance. + + As a declaration in :meth:`ConfigSectionConverter._create_weight_converters`, the ``fast_llm_name`` and + ``export_name`` are *section-relative*; the walker constructs a fully-qualified emitted copy by prepending + the section prefixes via :meth:`_emit`. Subclasses that need extra construction context (e.g. capturing + a sub-config for use inside ``export_weight``/``import_weight``) override :meth:`_emit` accordingly. + """ + def __init__( self, fast_llm_name: str | tuple[str, ...], @@ -736,54 +821,468 @@ def import_weight( ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: return weight + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list["WeightConverter"]: + """Return a fully-qualified emitted copy of this leaf. + + Subclasses that capture extra construction state (e.g. :class:`NestedWeightConverter` / + :class:`LinearWeightConverter` holding a sub-converter class or a callable prefix) override + this hook to pass that state through to the emitted output. + """ + return [ + type(self)( + prepend_prefix(fast_llm_prefix, self.fast_llm_name), + prepend_prefix(hf_prefix, self.export_name), + config, + ) + ] + -class IgnoreImportWeightConverter(WeightConverter): - def __post_init__(self): - Assert.eq(len(self.fast_llm_name), 0) - Assert.gt(len(self.export_name), 0) +class SplitWeightConverter(WeightConverter): + """Split a merged tensor evenly across the listed export names; concatenate on import.""" def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - raise RuntimeError( - f"IgnoreImportWeightConverter should not be used for export: {self.fast_llm_name}, {self.export_name}" - ) + (merged_weight,) = weight + return tuple(merged_weight[:].chunk(len(self.export_name))) def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return () + return (torch.cat([weight_[:] for weight_ in weight]),) + +class TransposeSplitWeightConverter(WeightConverter): + """Split a merged weight across the last dim with an additional transpose. -class IgnoreExportWeightConverter(WeightConverter): - def __post_init__(self): - Assert.gt(len(self.fast_llm_name), 0) - Assert.eq(len(self.export_name), 0) + The transpose only matters when the source tensor's last two dims differ; for 1-D biases and the + trivial single-name case it degenerates to a passthrough split. + """ def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return () + (merged_weight,) = weight + return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - raise RuntimeError( - f"IgnoreExportWeightConverter should not be used for import: {self.fast_llm_name}, {self.export_name}" - ) + merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) + return (merged_weight.t().contiguous(),) -class SplitWeightConverter(WeightConverter): +class KeyValueWeightConverter(WeightConverter): + """Pack/unpack a fused key-value tensor across the two HF names. + + Fast-LLM packs key/value as a single concatenated tensor; HF stores them as two siblings + (``k_proj`` / ``v_proj``). The same split applies to the 1-D bias. + """ + def export_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(merged_weight[:].chunk(len(self.export_name))) + (key_value,) = weight + key, value = key_value[:].chunk(2) + return key, value def import_weight( self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return (torch.cat([weight_[:] for weight_ in weight]),) + key, value = weight + key_value = torch.cat([key[:], value[:]]) + return (key_value,) + + +class PatchEmbeddingWeightConverter(WeightConverter): + """Reshape a vision patch-embedding weight from Fast-LLM's flat ``(out, channels*h*w)`` shape to HF's + ``(out, channels, h, w)`` (and back). Requires a config exposing ``input_channels``/``patch_height``/ + ``patch_width`` via the constructor's ``config`` argument.""" + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + return tuple( + weight_[:].view( + *weight_[:].shape[:-1], + self._config.input_channels, + self._config.patch_height, + self._config.patch_width, + ) + for weight_ in weight + ) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + return tuple( + weight_[:].view( + *weight_[:].shape[:-3], + self._config.input_channels * self._config.patch_height * self._config.patch_width, + ) + for weight_ in weight + ) + + +class OutputProjectionWeightConverter(WeightConverter): + """Marker for the LM-head output projection (typically ``head.output_weights`` ↔ ``lm_head.weight``). + + When the root config has ``tied_embedding_weight=True``, the walker drops this declaration entirely — + HF stores tied embeddings as just ``embed_tokens.weight`` with no separate ``lm_head.weight``. + """ + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + if root_config.tied_embedding_weight: + return [] + return super()._emit(config, fast_llm_prefix, hf_prefix, root_config=root_config) + + +class NestedWeightConverter(WeightConverter): + """Recurse into a sub-section's weight declarations. + + The sub-section's config is read by chained ``getattr`` from ``config`` via ``config_attr``: + + * ``None`` (default) — single attribute named after ``fast_llm_prefix``. + * single string — single attribute (e.g. one ``normalization`` config feeding two state-dict + prefixes). + * tuple of strings — chained descent (e.g. ``("dense", "pre_norm")`` ↦ ``config.dense.pre_norm``). + + The walker descends into ``sub_converter_class._create_weight_converters()`` with extended prefixes. + Mirrors :class:`NestedConfigConverter` on the config side. + + ``optional=True`` skips the recursion when the resolved sub-config is ``None`` — for optional + architecture sub-sections. + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + sub_converter_class: type["ConfigSectionConverter"], + *, + config_attr: str | tuple[str, ...] | None = None, + optional: bool = False, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._sub_converter_class = sub_converter_class + if config_attr is None: + self._config_attrs: tuple[str, ...] = (fast_llm_prefix,) + elif isinstance(config_attr, str): + self._config_attrs = (config_attr,) + else: + self._config_attrs = config_attr + self._optional = optional + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + sub_config: typing.Any = config + for attr in self._config_attrs: + sub_config = getattr(sub_config, attr) + if self._optional and sub_config is None: + return [] + return self._sub_converter_class.emit_weight_converters( + sub_config, + join_prefix(fast_llm_prefix, self._fast_llm_prefix), + join_prefix(hf_prefix, self._hf_prefix), + root_config=root_config, + ) + + +def _expand_block_sequence(block_sequence: Config) -> list[Config]: + """Materialize a ``Fixed``/``PatternBlockSequenceConfig`` into a per-position list of block configs. + + ``FixedBlockSequenceConfig``: single repeated block. ``PatternBlockSequenceConfig``: per-position + blocks indexed via ``decoder.expanded_pattern``. + """ + # Lazy import to keep external.py free of layers/ dependencies. + from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig + + if isinstance(block_sequence, FixedBlockSequenceConfig): + return [block_sequence.block] * block_sequence.num_blocks + if isinstance(block_sequence, PatternBlockSequenceConfig): + return [block_sequence.blocks[name] for name in block_sequence.expanded_pattern] + raise NotImplementedError(type(block_sequence).__name__) + + +class BlockSequenceWeightConverter(WeightConverter): + """Fan a single block converter across every position in a block sequence reached via attribute access. + + Reads the block sequence from ``getattr(parent, config_attr)`` (defaults to ``fast_llm_prefix``). + For per-position type dispatch (different mixer types per layer) use + :class:`DispatchBlockSequenceWeightConverter` instead; when the section config IS the block sequence + use :class:`SelfBlockSequenceWeightConverter`. + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + block_converter_class: type["ConfigSectionConverter"], + *, + config_attr: str | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._block_converter_class = block_converter_class + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + fast_llm_root = join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = join_prefix(hf_prefix, self._hf_prefix) + out: list[WeightConverter] = [] + for index, block in enumerate(_expand_block_sequence(getattr(config, self._config_attr))): + out += self._block_converter_class.emit_weight_converters( + block, + f"{fast_llm_root}.{index}", + f"{hf_root}.{index}", + root_config=root_config, + ) + return out + + +class DispatchBlockSequenceWeightConverter(WeightConverter): + """Fan a per-position-dispatched block converter across every position in a block sequence. + + Each position's block config is matched against ``dispatch_registry`` keys by its mixer type + (``type(block.mixer)``), allowing hybrid sequences where different layers use different block + converters. + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + dispatch_registry: dict[type[Config], type["ConfigSectionConverter"]], + *, + config_attr: str | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._dispatch_registry = dispatch_registry + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + fast_llm_root = join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = join_prefix(hf_prefix, self._hf_prefix) + out: list[WeightConverter] = [] + for index, block in enumerate(_expand_block_sequence(getattr(config, self._config_attr))): + block_class = self._dispatch_registry[type(block.mixer)] + out += block_class.emit_weight_converters( + block, + f"{fast_llm_root}.{index}", + f"{hf_root}.{index}", + root_config=root_config, + ) + return out + + +class SelfBlockSequenceWeightConverter(WeightConverter): + """Fan a single block converter across the section config when *the section IS the block sequence*. + + Used when the declaring section's ``fast_llm_config_class`` is itself a ``FixedBlockSequenceConfig`` + or ``PatternBlockSequenceConfig``. Weights land directly under the section's outer prefixes. + """ + + def __init__(self, block_converter_class: type["ConfigSectionConverter"]): + super().__init__((), ()) + self._block_converter_class = block_converter_class + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + out: list[WeightConverter] = [] + for index, block in enumerate(_expand_block_sequence(config)): + out += self._block_converter_class.emit_weight_converters( + block, + f"{fast_llm_prefix}.{index}", + f"{hf_prefix}.{index}", + root_config=root_config, + ) + return out + + +class DispatchWeightConverter(WeightConverter): + """Dispatch a single sub-section converter based on the live config's runtime type. + + Reads ``getattr(config, config_attr)`` (defaults to ``fast_llm_prefix``), looks up its type in + ``registry``, and recurses into that ConfigSectionConverter with the standard extended prefixes. + Mirrors :class:`DispatchConfigConverter` on the config side. Used when a single attribute holds one + of several alternative configs. + + ``hf_prefix_overrides`` lets individual branches replace the shared ``hf_prefix`` (e.g. when one + branch flat-merges into the parent root while siblings nest under a sub-prefix). + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + registry: dict[type[Config], type["ConfigSectionConverter"]], + *, + config_attr: str | None = None, + hf_prefix_overrides: dict[type[Config], str] | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._registry = registry + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + self._hf_prefix_overrides = hf_prefix_overrides or {} + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + sub_config = getattr(config, self._config_attr) + sub_type = type(sub_config) + sub_class = self._registry[sub_type] + own_hf_prefix = self._hf_prefix_overrides.get(sub_type, self._hf_prefix) + return sub_class.emit_weight_converters( + sub_config, + join_prefix(fast_llm_prefix, self._fast_llm_prefix), + join_prefix(hf_prefix, own_hf_prefix), + root_config=root_config, + ) + + +class TypedDictWeightConverter(WeightConverter): + """Per-key dispatch over a ``dict[str, Config]`` attribute. + + For each entry, looks up its type in ``registry`` and recurses into that converter with names + ``{fast_llm_prefix}.{key}`` / ``{hf_prefix}.{key}``. Mirrors + :class:`TypedDictContainerConfigConverter` on the config side. + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str, + registry: dict[type[Config], type["ConfigSectionConverter"]], + *, + config_attr: str | None = None, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._registry = registry + self._config_attr = config_attr if config_attr is not None else fast_llm_prefix + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + attr_dict = getattr(config, self._config_attr) + fast_llm_root = join_prefix(fast_llm_prefix, self._fast_llm_prefix) + hf_root = join_prefix(hf_prefix, self._hf_prefix) + out: list[WeightConverter] = [] + for name, sub_config in attr_dict.items(): + sub_class = self._registry[type(sub_config)] + out += sub_class.emit_weight_converters( + sub_config, + join_prefix(fast_llm_root, name), + join_prefix(hf_root, name), + root_config=root_config, + ) + return out + + +class LinearWeightConverter(WeightConverter): + """Bundle a linear layer's ``.weight`` and (conditionally) ``.bias`` declarations into one entry. + + Bias presence is resolved at emission time from the live section config: ``bias_fn`` is either a bool + literal (always / never) or a callable returning a bool. The default reads + ``config.add_linear_biases``; sections with per-layer overrides pass a lambda that resolves the + override. + + ``transform`` selects the leaf class for both weight and bias: :class:`WeightConverter` for plain + rename (the default), :class:`SplitWeightConverter` for fused → split, :class:`KeyValueWeightConverter` + for fused KV → separate K/V, :class:`TransposeSplitWeightConverter` for a transposed split. + """ + + def __init__( + self, + fast_llm_prefix: str, + hf_prefix: str | tuple[str, ...] | typing.Callable[[Config], str | tuple[str, ...]], + *, + transform: type[WeightConverter] = WeightConverter, + bias_fn: bool | typing.Callable[[Config], bool] = lambda c: c.add_linear_biases, + ): + super().__init__((), ()) + self._fast_llm_prefix = fast_llm_prefix + self._hf_prefix = hf_prefix + self._transform = transform + self._bias_fn = bias_fn + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + resolved = self._hf_prefix(config) if callable(self._hf_prefix) else self._hf_prefix + hf_prefixes: tuple[str, ...] = (resolved,) if isinstance(resolved, str) else tuple(resolved) + weight_fast_llm = prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.weight",)) + weight_hf = prepend_prefix(hf_prefix, tuple(f"{p}.weight" for p in hf_prefixes)) + emitted: list[WeightConverter] = [self._transform(weight_fast_llm, weight_hf, config)] + has_bias = self._bias_fn if isinstance(self._bias_fn, bool) else self._bias_fn(config) + if has_bias: + bias_fast_llm = prepend_prefix(fast_llm_prefix, (f"{self._fast_llm_prefix}.bias",)) + bias_hf = prepend_prefix(hf_prefix, tuple(f"{p}.bias" for p in hf_prefixes)) + emitted.append(self._transform(bias_fast_llm, bias_hf, config)) + return emitted class ExternalStateDictCheckpointHandler(StateDictCheckpointHandler): diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index 9074e72fc..c055a7f2c 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -23,31 +23,27 @@ import transformers -class HuggingFaceBaseModelConverter: - @classmethod - @abc.abstractmethod - def import_config(cls, config: dict) -> dict: - pass +class HuggingFaceBaseModelConverter(ConfigSectionConverter): + """Section converter for a full HF model root. Inherits the declarative config-side machinery from + :class:`ConfigSectionConverter` (``import_config`` / ``export_config`` driven by + ``_create_config_converters``) and adds the weight-side ``get_converters`` aggregation that the + enclosing :class:`HuggingfaceStateDictCheckpointHandler` invokes. + """ @classmethod - @abc.abstractmethod - def export_config(cls, config: BaseModelConfig) -> dict: - pass + def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]: + """Default: walk the section's weight declarations from the root. - @classmethod - @abc.abstractmethod - def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]: - pass + Subclasses override when a section needs cross-section state from the full base-model config + (typically when an extension point on the head must read from a sibling section). + """ + return cls.emit_weight_converters(config, "", "") class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC): architecture: typing.ClassVar[str] base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]] - def __init__(self, model: "FastLLMModel"): - self._exported_config = self._export_config(model.config) - super().__init__(model) - @classmethod @abc.abstractmethod def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]: @@ -164,13 +160,10 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]: def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None: """Run the HF-side coverage check at the import boundary. - Skips silently when the format's base-model converter isn't a :class:`ConfigSectionConverter` - (e.g. multimodal aggregators built on top of imperative ``HuggingFaceBaseModelConverter``). Subclasses that override :meth:`_import_config` should call this explicitly to keep the check active. """ - if issubclass(cls.base_model_converter_class, ConfigSectionConverter): - cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST) + cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST) @classmethod def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: @@ -180,7 +173,7 @@ def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig: return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)}) def _create_weight_converters(self) -> list[WeightConverter]: - return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config) + return self.base_model_converter_class.get_converters(self._model.config.base_model) def _load_weights( self, config: CheckpointLoadConfig, device diff --git a/fast_llm/models/gpt/conversion/apriel.py b/fast_llm/models/gpt/conversion/apriel.py index c09937aaa..103c42dae 100644 --- a/fast_llm/models/gpt/conversion/apriel.py +++ b/fast_llm/models/gpt/conversion/apriel.py @@ -1,3 +1,4 @@ +import functools import math import typing @@ -10,7 +11,9 @@ ConfigSectionConverter, CustomConfigConverter, DefaultConfigConverter, + DispatchBlockSequenceWeightConverter, IgnoredConfigConverter, + LinearWeightConverter, RenameConfigConverter, WeightConverter, _get_attr_path, @@ -18,19 +21,13 @@ ) from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.block.config import FixedBlockSequenceConfig, PatternBlockSequenceConfig -from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig -from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig +from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import AprielHybridSSMCheckpointFormat -from fast_llm.models.gpt.conversion.llama import ( - effective_bias, - get_parameter_converter, - get_weight_and_bias_converters, -) +from fast_llm.models.gpt.conversion.llama import effective_bias from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) from fast_llm.utils import Assert, safe_merge_dicts @@ -132,56 +129,23 @@ def _validate_export(cls, config: MambaConfig) -> None: Assert.incl(config.output_layer.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters( - cls, - config: MambaConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - # TODO: Conv - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj", - f"{hf_prefix}.in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_in_proj", - f"{hf_prefix}.dt_in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_proj", - f"{hf_prefix}.dt_proj", - effective_bias(config.dt_layer, config.add_linear_biases), - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.conv1d", - effective_bias(config.convolution_layer, config.add_linear_biases), - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.D", - f"{hf_prefix}.D", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "in_proj": LinearWeightConverter("in_proj", "in_proj"), + "dt_in_proj": LinearWeightConverter("dt_in_proj", "dt_in_proj"), + "dt_proj": LinearWeightConverter( + "dt_proj", "dt_proj", bias_fn=lambda c: effective_bias(c.dt_layer, c.add_linear_biases) + ), + "convolution": LinearWeightConverter( + "convolution", + "conv1d", + bias_fn=lambda c: effective_bias(c.convolution_layer, c.add_linear_biases), + ), + "A_log": WeightConverter("A_log", "A_log"), + "D": WeightConverter("D", "D"), + "out_proj": LinearWeightConverter("out_proj", "out_proj"), + } class GatedDeltaNetConverter(ConfigSectionConverter): @@ -223,55 +187,18 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: GatedDeltaNetConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_qkvz", - f"{hf_prefix}.in_proj_qkvz", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_ba", - f"{hf_prefix}.in_proj_ba", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.convolution", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - False, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # GDN has no ``add_linear_biases`` field; pass ``bias_fn=False`` so the default isn't consulted. + return { + "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=False), + "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=False), + "convolution": LinearWeightConverter("convolution", "convolution", bias_fn=False), + "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=False), + "A_log": WeightConverter("A_log", "A_log"), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "norm": LinearWeightConverter("norm", "norm", bias_fn=False), + } class KimiDeltaAttentionConverter(ConfigSectionConverter): @@ -317,103 +244,29 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: KimiDeltaAttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_proj", - f"{hf_prefix}.q_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_proj", - f"{hf_prefix}.k_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_proj", - f"{hf_prefix}.v_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_conv", - f"{hf_prefix}.q_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_conv", - f"{hf_prefix}.k_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_conv", - f"{hf_prefix}.v_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_a_proj", - f"{hf_prefix}.f_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_b_proj", - f"{hf_prefix}.f_b_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_a_proj", - f"{hf_prefix}.g_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_b_proj", - f"{hf_prefix}.g_b_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.beta_proj", - f"{hf_prefix}.beta_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.o_proj", - f"{hf_prefix}.o_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - False, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # KimiDeltaAttention has no linear biases. + proj_names = ( + "q_proj", + "k_proj", + "v_proj", + "q_conv", + "k_conv", + "v_conv", + "f_a_proj", + "f_b_proj", + "g_a_proj", + "g_b_proj", + "beta_proj", + "o_proj", + ) + return { + **{name: LinearWeightConverter(name, name, bias_fn=False) for name in proj_names}, + "A_log": WeightConverter("A_log", "A_log"), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "norm": LinearWeightConverter("norm", "norm", bias_fn=False), + } class AprielKimiDeltaAttentionBlockConverter(MistralBlockConverter): @@ -520,14 +373,14 @@ def _consumed_hf_paths(self) -> frozenset[tuple[str, ...]]: class AprielBlockConverter: - """Per-mixer-type Apriel block-converter registries plus a weight-side dispatch helper. + """Registry holder for Apriel's per-mixer-type block converters. ``layout_names`` maps the mixer-config classes that participate in Apriel's - ``hybrid_block_layout`` discriminator to their string layout names. ``_converter_classes`` - maps every mixer-config class whose weights can appear in an Apriel checkpoint to its block - converter — a superset of ``layout_names`` keys that adds ``KimiDeltaAttentionConfig`` for - weight-only coverage. :meth:`get_converters` picks the right block converter from - ``_converter_classes`` by the mixer's runtime type. + ``hybrid_block_layout`` discriminator to their string layout names — consumed by + :class:`ListDispatchConfigConverter` on the config side. ``_converter_classes`` maps every + mixer-config class whose weights can appear in an Apriel checkpoint to its block converter + (a superset of ``layout_names`` keys that adds ``KimiDeltaAttentionConfig`` for weight-only + coverage) — consumed by ``DispatchBlockSequenceWeightConverter`` on the weight side. """ layout_names: typing.ClassVar[dict[type[Config], str]] = { @@ -542,22 +395,6 @@ class AprielBlockConverter: GatedDeltaNetConfig: AprielGatedDeltaNetBlockConverter, } - @classmethod - def get_converters( - cls, - config: DecoderBlockConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return cls._converter_classes[type(config.mixer)].get_converters( - config, fast_llm_prefix, hf_prefix, drop_on_export=drop_on_export - ) - - -class AprielHeadConverter(MistralHeadConverter): - block_converter_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter - class AprielBaseModelConverter(MistralBaseModelConverter): """Section converter for the Apriel hybrid-SSM base model. @@ -567,10 +404,9 @@ class AprielBaseModelConverter(MistralBaseModelConverter): HF keys flat-merge into the parent HF root. """ - head_converter_class: typing.ClassVar[type[AprielHeadConverter]] = AprielHeadConverter # Distinct from the parent's ``block_converter_class`` (a single ``ConfigSectionConverter``); this # one holds the per-mixer-type dispatch registries that :class:`ListDispatchConfigConverter` and - # the weight-side loop below consume. + # :class:`DispatchBlockSequenceWeightConverter` consume. block_dispatcher_class: typing.ClassVar[type[AprielBlockConverter]] = AprielBlockConverter @classmethod @@ -587,23 +423,16 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - decoder = config.decoder - if isinstance(decoder, FixedBlockSequenceConfig): - per_position_blocks = [decoder.block] * decoder.num_blocks - elif isinstance(decoder, PatternBlockSequenceConfig): - per_position_blocks = [decoder.blocks[block_name] for block_name in decoder.expanded_pattern] - else: - raise NotImplementedError(type(decoder).__name__) - converters = [*cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")] - for block_index, block_config in enumerate(per_position_blocks): - converters += cls.block_dispatcher_class.get_converters( - block_config, - f"decoder.{block_index}", - f"model.layers.{block_index}", - ) - converters += cls.head_converter_class.get_converters(config, exported_config) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + **super()._create_weight_converters(), + "decoder": DispatchBlockSequenceWeightConverter( + "decoder", + "model.layers", + cls.block_dispatcher_class._converter_classes, + ), + } class AprielHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/apriel2.py b/fast_llm/models/gpt/conversion/apriel2.py index d5af6b572..956225b40 100644 --- a/fast_llm/models/gpt/conversion/apriel2.py +++ b/fast_llm/models/gpt/conversion/apriel2.py @@ -1,5 +1,6 @@ """Apriel2 text-only checkpoint format converter.""" +import functools import typing from transformers import PretrainedConfig @@ -7,18 +8,27 @@ from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantImportConfigConverter, CustomConfigConverter, DispatchConfigConverter, + DispatchWeightConverter, IgnoredConfigConverter, + KeyValueWeightConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, OptionalConfigConverter, + OutputProjectionWeightConverter, RenameConfigConverter, + SplitWeightConverter, + TransposeSplitWeightConverter, TypedDictContainerConfigConverter, + TypedDictWeightConverter, WeightConverter, ) -from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler +from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig @@ -35,15 +45,10 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Apriel2TextCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( - KeyValueWeightConverter, LlamaEmbeddingsConverter, LlamaNormalizationConverter, - MLPLayer2Converter, - SplitWeightConverter, assert_no_peft, effective_bias, - get_parameter_converter, - get_weight_and_bias_converters, ) from fast_llm.models.gpt.model import GPTModel from fast_llm.utils import Assert, safe_merge_dicts @@ -177,42 +182,28 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - q_bias = effective_bias(config.query_layer, config.add_linear_biases) - k_bias = effective_bias(config.key_layer, config.add_linear_biases) - v_bias = effective_bias(config.value_layer, config.add_linear_biases) - o_bias = effective_bias(config.dense_layer, config.add_linear_biases) - # k_proj and v_proj are merged in Fast-LLM's key_value layer; treat as biased only if both sides agree. - kv_bias = k_bias and v_bias - - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - q_bias, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - kv_bias, - KeyValueWeightConverter, - config, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # Each linear layer carries its own ``bias.enabled`` override; the default falls back to the + # mixer-wide ``add_linear_biases`` via :func:`effective_bias`. ``key_value`` is biased only when + # both K and V agree (Fast-LLM packs them as a single tensor). + return { + "query": LinearWeightConverter( + "query", "q_proj", bias_fn=lambda c: effective_bias(c.query_layer, c.add_linear_biases) + ), + "key_value": LinearWeightConverter( + "key_value", + ("k_proj", "v_proj"), + transform=KeyValueWeightConverter, + bias_fn=lambda c: ( + effective_bias(c.key_layer, c.add_linear_biases) + and effective_bias(c.value_layer, c.add_linear_biases) + ), ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - o_bias, - drop_on_export=drop_on_export, + "dense": LinearWeightConverter( + "dense", "o_proj", bias_fn=lambda c: effective_bias(c.dense_layer, c.add_linear_biases) ), - ] + } def _apriel2_mamba_aux_export(config: MambaConfig) -> dict: @@ -274,55 +265,22 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: MambaConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj", - f"{hf_prefix}.in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_in_proj", - f"{hf_prefix}.dt_in_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dt_proj", - f"{hf_prefix}.dt_proj", - config.dt_layer.bias.enabled, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.conv1d", - config.convolution_layer.bias.enabled, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.D", - f"{hf_prefix}.D", - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # dt_proj and convolution read per-layer ``bias.enabled`` directly (no fallback to the mixer-wide + # flag — Apriel2's HF surfaces these biases via the dedicated ``dt_proj_bias`` / ``conv_bias`` + # auxiliary keys rather than via ``add_linear_biases``). + return { + "in_proj": LinearWeightConverter("in_proj", "in_proj"), + "dt_in_proj": LinearWeightConverter("dt_in_proj", "dt_in_proj"), + "dt_proj": LinearWeightConverter("dt_proj", "dt_proj", bias_fn=lambda c: c.dt_layer.bias.enabled), + "convolution": LinearWeightConverter( + "convolution", "conv1d", bias_fn=lambda c: c.convolution_layer.bias.enabled + ), + "A_log": WeightConverter("A_log", "A_log"), + "D": WeightConverter("D", "D"), + "out_proj": LinearWeightConverter("out_proj", "out_proj"), + } class Apriel2GatedDeltaNetConverter(ConfigSectionConverter): @@ -356,55 +314,19 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: GatedDeltaNetConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_qkvz", - f"{hf_prefix}.in_proj_qkvz", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.in_proj_ba", - f"{hf_prefix}.in_proj_ba", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.convolution", - f"{hf_prefix}.convolution", - config.convolution_layer.bias.enabled, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.out_proj", - f"{hf_prefix}.out_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "in_proj_qkvz": LinearWeightConverter("in_proj_qkvz", "in_proj_qkvz", bias_fn=False), + "in_proj_ba": LinearWeightConverter("in_proj_ba", "in_proj_ba", bias_fn=False), + "convolution": LinearWeightConverter( + "convolution", "convolution", bias_fn=lambda c: c.convolution_layer.bias.enabled + ), + "out_proj": LinearWeightConverter("out_proj", "out_proj", bias_fn=False), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "A_log": WeightConverter("A_log", "A_log"), + "norm": NestedWeightConverter("norm", "norm", LlamaNormalizationConverter, config_attr="normalization"), + } class Apriel2KimiDeltaAttentionConverter(ConfigSectionConverter): @@ -451,103 +373,28 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: KimiDeltaAttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_proj", - f"{hf_prefix}.q_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_proj", - f"{hf_prefix}.k_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_proj", - f"{hf_prefix}.v_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.q_conv", - f"{hf_prefix}.q_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.k_conv", - f"{hf_prefix}.k_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.v_conv", - f"{hf_prefix}.v_conv", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_a_proj", - f"{hf_prefix}.f_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.f_b_proj", - f"{hf_prefix}.f_b_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_a_proj", - f"{hf_prefix}.g_a_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.g_b_proj", - f"{hf_prefix}.g_b_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.beta_proj", - f"{hf_prefix}.beta_proj", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.o_proj", - f"{hf_prefix}.o_proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.A_log", - f"{hf_prefix}.A_log", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.dt_bias", - f"{hf_prefix}.dt_bias", - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm", - f"{hf_prefix}.norm", - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + proj_names = ( + "q_proj", + "k_proj", + "v_proj", + "q_conv", + "k_conv", + "v_conv", + "f_a_proj", + "f_b_proj", + "g_a_proj", + "g_b_proj", + "beta_proj", + "o_proj", + ) + return { + **{name: LinearWeightConverter(name, name, bias_fn=False) for name in proj_names}, + "A_log": WeightConverter("A_log", "A_log"), + "dt_bias": WeightConverter("dt_bias", "dt_bias"), + "norm": NestedWeightConverter("norm", "norm", LlamaNormalizationConverter, config_attr="normalization"), + } # Mixer dispatch registry — used inside StochasticMixer (no nested-stochastic) and at the block level. @@ -580,27 +427,9 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: StochasticMixerConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters = [] - for name, sub_mixer in config.mixers.items(): - converter_class = APRIEL2_LEAF_MIXER_REGISTRY.get(type(sub_mixer)) - if converter_class is None: - raise ValueError(f"Unknown sub-mixer type: {type(sub_mixer)}") - converters.extend( - converter_class.get_converters( - sub_mixer, - f"{fast_llm_prefix}.mixers.{name}", - f"{hf_prefix}.mixers.{name}", - drop_on_export=drop_on_export, - ) - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return {"mixers": TypedDictWeightConverter("mixers", "mixers", APRIEL2_LEAF_MIXER_REGISTRY)} # Block-level mixer registry includes StochasticMixer (which can wrap leaf mixers). @@ -693,48 +522,24 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: MLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - layer_1_bias = effective_bias(config.layer_1, config.add_linear_biases) - layer_2_bias = effective_bias(config.layer_2, config.add_linear_biases) - if config.gated: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), - layer_1_bias, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - f"{hf_prefix}.up_proj", - layer_1_bias, - WeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - layer_2_bias, - MLPLayer2Converter, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``layer_1`` splits into ``(gate_proj, up_proj)`` when gated, but stays as a single ``up_proj`` + # otherwise. The transform and HF prefix both depend on the live config; resolve at emit time. + return { + "layer_1": LinearWeightConverter( + "layer_1", + lambda c: ("gate_proj", "up_proj") if c.gated else ("up_proj",), + transform=SplitWeightConverter, + bias_fn=lambda c: effective_bias(c.layer_1, c.add_linear_biases), + ), + "layer_2": LinearWeightConverter( + "layer_2", + "down_proj", + transform=TransposeSplitWeightConverter, + bias_fn=lambda c: effective_bias(c.layer_2, c.add_linear_biases), ), - ] + } class Apriel2BlockConverter(ConfigSectionConverter): @@ -775,52 +580,26 @@ def _validate_export(cls, config: DecoderBlockConfig) -> None: Assert.custom(lambda v: not v, config.output_scale.enabled) @classmethod - def get_converters( - cls, - config: DecoderBlockConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - mixer_converter_class = APRIEL2_BLOCK_MIXER_REGISTRY.get(type(config.mixer)) - if mixer_converter_class is None: - raise ValueError(f"Unknown mixer type: {type(config.mixer)}") - converters: list[WeightConverter] = list( - mixer_converter_class.get_converters( - config.mixer, - f"{fast_llm_prefix}.mixer", - f"{hf_prefix}.mixer", - drop_on_export=drop_on_export, - ) - ) - converters.extend( - Apriel2MLPConverter.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - f"{hf_prefix}.mlp", - drop_on_export=drop_on_export, - ) - ) - converters.extend( - [ - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.input_layernorm", - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.post_attention_layernorm", - drop_on_export=drop_on_export, - ), - ] - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": DispatchWeightConverter("mixer", "mixer", APRIEL2_BLOCK_MIXER_REGISTRY), + "mlp": NestedWeightConverter("mlp", "mlp", Apriel2MLPConverter), + # The two state-dict norms (norm_1/norm_2) share the block's single ``normalization`` config. + "norm_1": NestedWeightConverter( + "norm_1", "input_layernorm", LlamaNormalizationConverter, config_attr="normalization" + ), + "norm_2": NestedWeightConverter( + "norm_2", "post_attention_layernorm", LlamaNormalizationConverter, config_attr="normalization" + ), + } class Apriel2FixedDecoderConverter(ConfigSectionConverter): + """Config-only section: block fan-out lives on the base-model converter via + :class:`BlockSequenceWeightConverter`. This class exists for the config-side + :class:`DispatchConfigConverter` between Fixed and Pattern decoder shapes.""" + fast_llm_config_class = FixedBlockSequenceConfig hf_type_name = "fixed" block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter @@ -832,26 +611,10 @@ def _create_config_converters(cls) -> dict: "block": NestedConfigConverter(("block",), cls.block_converter_class, hf_path=("block",)), } - @classmethod - def get_converters( - cls, - config: FixedBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export=drop_on_export, - ) - return converters - class Apriel2PatternDecoderConverter(ConfigSectionConverter): + """Config-only section; see :class:`Apriel2FixedDecoderConverter`.""" + fast_llm_config_class = PatternBlockSequenceConfig hf_type_name = "pattern" block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter @@ -868,25 +631,6 @@ def _create_config_converters(cls) -> dict: ), } - @classmethod - def get_converters( - cls, - config: PatternBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - for block_index in range(config.num_blocks): - block_config = config.blocks[config.pattern[block_index % len(config.pattern)]] - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export=drop_on_export, - ) - return converters - APRIEL2_DECODER_REGISTRY: dict[type[Config], type[ConfigSectionConverter]] = { FixedBlockSequenceConfig: Apriel2FixedDecoderConverter, @@ -894,16 +638,6 @@ def get_converters( } -def get_apriel2_decoder_converter( - decoder_config: FixedBlockSequenceConfig | PatternBlockSequenceConfig, -) -> type[ConfigSectionConverter]: - """Look up the Apriel2 per-shape decoder converter for a given decoder config instance.""" - converter_class = APRIEL2_DECODER_REGISTRY.get(type(decoder_config)) - if converter_class is None: - raise NotImplementedError(f"Unsupported decoder type: {type(decoder_config).__name__}") - return converter_class - - class Apriel2HeadConverter(ConfigSectionConverter): fast_llm_config_class = LanguageModelHeadConfig @@ -932,31 +666,21 @@ def _validate_export(cls, config: LanguageModelHeadConfig) -> None: Assert.is_(type(config.normalization), RMSNormalizationConfig) @classmethod - def get_converters( - cls, - config: LanguageModelHeadConfig, - exported_config: dict, - fast_llm_prefix: str, - ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.final_norm", - "model.norm", - ), - get_parameter_converter( - f"{fast_llm_prefix}.output_weights", - "lm_head.weight", - drop_on_import=exported_config.get("tie_word_embeddings", False), - drop_on_export=exported_config.get("tie_word_embeddings", False), + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "final_norm": NestedWeightConverter( + "final_norm", "model.norm", cls.normalization_converter_class, config_attr="normalization" ), - ] + "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), + } -class Apriel2BaseModelConverter(ConfigSectionConverter): +class Apriel2BaseModelConverter(HuggingFaceBaseModelConverter): fast_llm_config_class = GPTBaseModelConfig embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @classmethod @@ -986,14 +710,13 @@ def _validate_export(cls, config: GPTBaseModelConfig) -> None: assert_no_peft(config) @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - return [ - *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *get_apriel2_decoder_converter(config.decoder).get_converters( - config.decoder, "decoder", "model.decoder.blocks" - ), - *cls.head_converter_class.get_converters(config.head, exported_config, "head"), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), + "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", cls.block_converter_class), + "head": NestedWeightConverter("head", "", cls.head_converter_class), + } class Apriel2HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @@ -1039,7 +762,3 @@ def _export_config(cls, config: GPTModelConfig) -> dict[str, typing.Any]: def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: cls._check_hf_coverage(config) return {"base_model": cls.base_model_converter_class.import_config(config)} - - @classmethod - def _get_weight_converters(cls, config: GPTModelConfig, export_config: dict) -> list[WeightConverter]: - return cls.base_model_converter_class.get_converters(config.base_model, export_config) diff --git a/fast_llm/models/gpt/conversion/gemma4.py b/fast_llm/models/gpt/conversion/gemma4.py index 6cbb7898c..75d65911c 100644 --- a/fast_llm/models/gpt/conversion/gemma4.py +++ b/fast_llm/models/gpt/conversion/gemma4.py @@ -1,17 +1,27 @@ """Gemma4 checkpoint format converter.""" +import functools import typing +import torch + from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConfigSectionConverter, ConstantExportConfigConverter, CustomConfigConverter, + DispatchWeightConverter, IgnoredConfigConverter, + KeyValueWeightConverter, + LinearWeightConverter, + NestedWeightConverter, RenameConfigConverter, + SelfBlockSequenceWeightConverter, SplitWeightConverter, + TransposeSplitWeightConverter, WeightConverter, + join_prefix, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler from fast_llm.functional.config import ActivationType @@ -28,15 +38,12 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import Gemma4CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( - KeyValueWeightConverter, LlamaEmbeddingsConverter, LlamaHeadConverter, LlamaNormalizationConverter, - MLPLayer2Converter, - get_parameter_converter, - get_weight_and_bias_converters, ) from fast_llm.models.gpt.model import GPTModel +from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, safe_merge_dicts _SLIDING_ATTENTION = "sliding_attention" @@ -75,7 +82,72 @@ def import_weight(self, weight): return (w.permute(0, 2, 1).reshape(-1, w.shape[1]).contiguous(),) -class Gemma4AttentionConverter: +class _Gemma4BlockNorm2WeightConverter(WeightConverter): + """Dense Gemma4 blocks store the pre-MLP norm at ``norm_2`` (drawn from the block's main + ``normalization`` config). MoE blocks suppress this — the routed/dense branches inside the + hybrid MoE own their own pre/post norms. The conditional emit reads a sibling attribute + (``config.mlp``) and can't be expressed via :class:`NestedWeightConverter.optional`. + """ + + def __init__(self) -> None: + super().__init__((), ()) + + def _emit( + self, + config: Config, + fast_llm_prefix: str, + hf_prefix: str, + *, + root_config: Config, + ) -> list[WeightConverter]: + if isinstance(config.mlp, HybridMoEMLPConfig): + return [] + return LlamaNormalizationConverter.emit_weight_converters( + config.normalization, + join_prefix(fast_llm_prefix, "norm_2"), + join_prefix(hf_prefix, "pre_feedforward_layernorm"), + root_config=root_config, + ) + + +class _Gemma4SharedKeyValueWeightConverter(KeyValueWeightConverter): + """``shared_key_value=True`` Gemma4 attention: Fast-LLM's ``key_value`` is a single K-shaped + tensor (V is reused at runtime) and maps to a single HF ``k_proj`` — plain rename. Delegates to + :class:`KeyValueWeightConverter` (chunk/cat across K and V) when not shared. + """ + + _config: AttentionConfig + + def export_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + if self._config.shared_key_value: + return weight + return super().export_weight(weight) + + def import_weight( + self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] + ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: + if self._config.shared_key_value: + return weight + return super().import_weight(weight) + + +class Gemma4AttentionConverter(ConfigSectionConverter): + """Gemma4's attention helper: ``import_config`` / ``export_config`` take non-standard arguments + (sliding/full discrimination, twin block exports) and are invoked imperatively from + :class:`Gemma4BlockConverter`. Only the weight side fits the standard declarative shape — biases + are always disabled, query/key norms are emitted only when present, and the K/V layout collapses + to a single ``k_proj`` when ``shared_key_value`` is set. + + The config side is owned by :class:`Gemma4BaseModelConverter`'s ``decoder`` :class:`CustomConfigConverter` + (with ``fast_llm_recurses=True``); the ``weight_only`` flag below signals that — Gemma4's sliding/full + divergence prevents a uniform declarative shape per single block. + """ + + fast_llm_config_class = AttentionConfig + weight_only: typing.ClassVar[bool] = True + @classmethod def import_config(cls, config: dict, is_sliding: bool) -> dict: eps = config["rms_norm_eps"] @@ -154,64 +226,28 @@ def export_config(cls, sliding_config: AttentionConfig, full_config: AttentionCo } @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - if config.shared_key_value: - # K=V: single k_proj reused as value; no v_proj in HF - kv_converters = get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - f"{hf_prefix}.k_proj", - False, - drop_on_export=drop_on_export, - ) - else: - kv_converters = get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - False, - KeyValueWeightConverter, - config, - drop_on_export=drop_on_export, - ) - converters = [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - False, - drop_on_export=drop_on_export, - ), - *kv_converters, - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - False, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "query": LinearWeightConverter("query", "q_proj", bias_fn=False), + "key_value": LinearWeightConverter( + "key_value", + lambda c: "k_proj" if c.shared_key_value else ("k_proj", "v_proj"), + transform=_Gemma4SharedKeyValueWeightConverter, + bias_fn=False, ), - ] - if config.query_norm is not None: - converters += LlamaNormalizationConverter.get_converters( - config.query_norm, - f"{fast_llm_prefix}.query_norm", - f"{hf_prefix}.q_norm", - drop_on_export=drop_on_export, - ) - if config.key_norm is not None: - converters += LlamaNormalizationConverter.get_converters( - config.key_norm, - f"{fast_llm_prefix}.key_norm", - f"{hf_prefix}.k_norm", - drop_on_export=drop_on_export, - ) - # value_norm is FixedRMSNorm — no learnable weight to convert - return converters + "dense": LinearWeightConverter("dense", "o_proj", bias_fn=False), + # ``value_norm`` is :class:`FixedRMSNormConfig` (no learnable weight) — not declared. + "query_norm": NestedWeightConverter("query_norm", "q_norm", LlamaNormalizationConverter, optional=True), + "key_norm": NestedWeightConverter("key_norm", "k_norm", LlamaNormalizationConverter, optional=True), + } + +class Gemma4MLPConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. + fast_llm_config_class = MLPConfig + weight_only: typing.ClassVar[bool] = True -class Gemma4MLPConverter: @classmethod def import_config(cls, config: dict) -> dict: return { @@ -234,32 +270,23 @@ def export_config(cls, config: MLPConfig) -> dict: } @classmethod - def get_converters( - cls, - config: MLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), - False, - SplitWeightConverter, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter( + "layer_1", ("gate_proj", "up_proj"), transform=SplitWeightConverter, bias_fn=False ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - False, - MLPLayer2Converter, - drop_on_export=drop_on_export, + "layer_2": LinearWeightConverter( + "layer_2", "down_proj", transform=TransposeSplitWeightConverter, bias_fn=False ), - ] + } -class Gemma4MoEMLPConverter: +class Gemma4MoEMLPConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. + fast_llm_config_class = MoEMLPConfig + weight_only: typing.ClassVar[bool] = True + @classmethod def import_config(cls, config: dict) -> dict: eps = config["rms_norm_eps"] @@ -306,50 +333,23 @@ def export_config(cls, config: MoEMLPConfig, hidden_size: int) -> dict: } @classmethod - def get_converters( - cls, - config: MoEMLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters = [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.router", - f"{hf_prefix}.router.proj", - False, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.router_scale", - f"{hf_prefix}.router.scale", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.router_per_expert_scale", - f"{hf_prefix}.router.per_expert_scale", - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.layer_1.weight", - f"{hf_prefix}.experts.gate_up_proj", - Gemma4MoELayer1Converter, - config, - drop_on_export=drop_on_export, - ), - get_parameter_converter( - f"{fast_llm_prefix}.layer_2.weight", - f"{hf_prefix}.experts.down_proj", - Gemma4MoELayer2Converter, - config, - drop_on_export=drop_on_export, - ), - ] - # router.norm is FixedRMSNorm — no learnable weight to convert. - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``router.norm`` is :class:`FixedRMSNormConfig` (no learnable weight) — not declared. + return { + "router": LinearWeightConverter("router", "router.proj", bias_fn=False), + "router_scale": WeightConverter("router_scale", "router.scale"), + "router_per_expert_scale": WeightConverter("router_per_expert_scale", "router.per_expert_scale"), + "layer_1": Gemma4MoELayer1Converter("layer_1.weight", "experts.gate_up_proj"), + "layer_2": Gemma4MoELayer2Converter("layer_2.weight", "experts.down_proj"), + } + +class Gemma4HybridMoEMLPConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. + fast_llm_config_class = HybridMoEMLPConfig + weight_only: typing.ClassVar[bool] = True -class Gemma4HybridMoEMLPConverter: @classmethod def import_config(cls, config: dict) -> dict: eps = config["rms_norm_eps"] @@ -375,54 +375,44 @@ def export_config(cls, config: HybridMoEMLPConfig, hidden_size: int) -> dict: ) @classmethod - def get_converters( - cls, - config: HybridMoEMLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *Gemma4MLPConverter.get_converters( - config.dense, - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.mlp", - drop_on_export=drop_on_export, - ), - *Gemma4MoEMLPConverter.get_converters( - config.routed, - f"{fast_llm_prefix}.routed", - hf_prefix, - drop_on_export=drop_on_export, - ), - *LlamaNormalizationConverter.get_converters( - config.dense.pre_norm, - f"{fast_llm_prefix}.dense.pre_norm", - f"{hf_prefix}.pre_feedforward_layernorm", - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "dense": NestedWeightConverter("dense", "mlp", Gemma4MLPConverter), + # Routed branch lives at the block's HF root (sibling of ``mlp.<...>``). + "routed": NestedWeightConverter("routed", "", Gemma4MoEMLPConverter), + "dense_pre_norm": NestedWeightConverter( + "dense.pre_norm", + "pre_feedforward_layernorm", + LlamaNormalizationConverter, + config_attr=("dense", "pre_norm"), ), - *LlamaNormalizationConverter.get_converters( - config.dense.post_norm, - f"{fast_llm_prefix}.dense.post_norm", - f"{hf_prefix}.post_feedforward_layernorm_1", - drop_on_export=drop_on_export, + "dense_post_norm": NestedWeightConverter( + "dense.post_norm", + "post_feedforward_layernorm_1", + LlamaNormalizationConverter, + config_attr=("dense", "post_norm"), ), - *LlamaNormalizationConverter.get_converters( - config.routed.pre_norm, - f"{fast_llm_prefix}.routed.pre_norm", - f"{hf_prefix}.pre_feedforward_layernorm_2", - drop_on_export=drop_on_export, + "routed_pre_norm": NestedWeightConverter( + "routed.pre_norm", + "pre_feedforward_layernorm_2", + LlamaNormalizationConverter, + config_attr=("routed", "pre_norm"), ), - *LlamaNormalizationConverter.get_converters( - config.routed.post_norm, - f"{fast_llm_prefix}.routed.post_norm", - f"{hf_prefix}.post_feedforward_layernorm_2", - drop_on_export=drop_on_export, + "routed_post_norm": NestedWeightConverter( + "routed.post_norm", + "post_feedforward_layernorm_2", + LlamaNormalizationConverter, + config_attr=("routed", "post_norm"), ), - ] + } + +class Gemma4BlockConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. + fast_llm_config_class = DecoderBlockConfig + weight_only: typing.ClassVar[bool] = True -class Gemma4BlockConverter: @classmethod def import_config(cls, config: dict, is_sliding: bool) -> dict: def make_norm(): @@ -471,73 +461,44 @@ def export_config( return out @classmethod - def get_converters( - cls, - config: DecoderBlockConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - is_moe = isinstance(config.mlp, HybridMoEMLPConfig) - converters = [ - *Gemma4AttentionConverter.get_converters( - config.mixer, - f"{fast_llm_prefix}.mixer", - f"{hf_prefix}.self_attn", - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": NestedWeightConverter("mixer", "self_attn", Gemma4AttentionConverter), + # Dense MLP lands under ``mlp.<...>``; hybrid MoE flat-merges into the block's HF root. + "mlp": DispatchWeightConverter( + "mlp", + "mlp", + registry={MLPConfig: Gemma4MLPConverter, HybridMoEMLPConfig: Gemma4HybridMoEMLPConverter}, + hf_prefix_overrides={HybridMoEMLPConfig: ""}, ), - ] - if is_moe: - converters += Gemma4HybridMoEMLPConverter.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - hf_prefix, - drop_on_export=drop_on_export, - ) - else: - converters += Gemma4MLPConverter.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - f"{hf_prefix}.mlp", - drop_on_export=drop_on_export, - ) - converters += LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.pre_feedforward_layernorm", - drop_on_export=drop_on_export, - ) - converters += [ - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.input_layernorm", - drop_on_export=drop_on_export, + "norm_1": NestedWeightConverter( + "norm_1", "input_layernorm", LlamaNormalizationConverter, config_attr="normalization" ), - *LlamaNormalizationConverter.get_converters( - config.post_mixer_normalization, - f"{fast_llm_prefix}.post_mixer_norm", - f"{hf_prefix}.post_attention_layernorm", - drop_on_export=drop_on_export, + "norm_2": _Gemma4BlockNorm2WeightConverter(), + "post_mixer_norm": NestedWeightConverter( + "post_mixer_norm", + "post_attention_layernorm", + LlamaNormalizationConverter, + config_attr="post_mixer_normalization", ), - *LlamaNormalizationConverter.get_converters( - config.post_mlp_normalization, - f"{fast_llm_prefix}.post_mlp_norm", - f"{hf_prefix}.post_feedforward_layernorm", - drop_on_export=drop_on_export, + "post_mlp_norm": NestedWeightConverter( + "post_mlp_norm", + "post_feedforward_layernorm", + LlamaNormalizationConverter, + config_attr="post_mlp_normalization", ), - ] - converters.append( - get_parameter_converter( - f"{fast_llm_prefix}.output_scale", - f"{hf_prefix}.layer_scalar", - drop_on_export=drop_on_export, - ) - ) - return converters + # HF stores ``layer_scalar`` as a non-trained buffer; Fast-LLM mirrors it with a frozen + # ``output_scale`` (``lr_scale=0``). + "output_scale": WeightConverter("output_scale", "layer_scalar"), + } -class Gemma4DecoderConverter: +class Gemma4DecoderConverter(ConfigSectionConverter): + # Config side owned by the aggregator's ``decoder`` CustomConfigConverter; see Gemma4AttentionConverter. + fast_llm_config_class = PatternBlockSequenceConfig + weight_only: typing.ClassVar[bool] = True + block_converter_class: typing.ClassVar[type[Gemma4BlockConverter]] = Gemma4BlockConverter @classmethod @@ -573,24 +534,11 @@ def export_config(cls, config: PatternBlockSequenceConfig, hidden_size: int) -> ) @classmethod - def get_converters( - cls, - config: PatternBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - Assert.custom(isinstance, config, PatternBlockSequenceConfig) - converters = [] - for block_index in range(config.num_blocks): - block_config = config.blocks[config.expanded_pattern[block_index]] - converters += cls.block_converter_class.get_converters( - block_config, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export=drop_on_export, - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), + } class Gemma4EmbeddingsConverter(LlamaEmbeddingsConverter): @@ -645,7 +593,7 @@ def _gemma4_bidirectional_import(hf_dict: dict) -> dict: return {} -class Gemma4BaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class Gemma4BaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for ``GPTBaseModelConfig`` ↔ Gemma 4 HF dict. Gemma 4 has several wrinkles that prevent the standard per-section decomposition used by Llama: @@ -754,11 +702,15 @@ def _head_import(hf_dict: dict) -> dict: } @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), - *cls.head_converter_class.get_converters(config, exported_config), + *cls.embeddings_converter_class.emit_weight_converters( + config.embeddings, "embeddings", "model", root_config=config + ), + *cls.decoder_converter_class.emit_weight_converters( + config.decoder, "decoder", "model.layers", root_config=config + ), + *cls.head_converter_class.get_converters(config), ] diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 425a198f6..d57665e6c 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -1,23 +1,28 @@ import dataclasses +import functools import logging import typing -import torch import transformers from fast_llm.config import Config from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantImportConfigConverter, CustomConfigConverter, DefaultConfigConverter, IgnoredConfigConverter, - IgnoreExportWeightConverter, - IgnoreImportWeightConverter, + KeyValueWeightConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, + OutputProjectionWeightConverter, RenameConfigConverter, + SelfBlockSequenceWeightConverter, SplitWeightConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -32,14 +37,12 @@ from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig from fast_llm.layers.language_model.config import ( - LanguageModelConfig, LanguageModelEmbeddingsConfig, LanguageModelHeadConfig, ) from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import LlamaCheckpointFormat from fast_llm.models.gpt.model import GPTModel -from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div _TRANSFORMERS_V4 = not dataclasses.is_dataclass(transformers.PretrainedConfig) @@ -58,107 +61,6 @@ def effective_bias(layer_config: AffineLinearConfig, default: bool) -> bool: return default if layer_config.bias.enabled is None else layer_config.bias.enabled -# ============================================================ -# Weight converters (imperative) -# ============================================================ - - -def get_parameter_converter( - fast_llm_name: str | tuple[str, ...], - hf_name: str | tuple[str, ...], - cls=WeightConverter, - config=None, - drop_on_export: bool = False, - drop_on_import: bool = False, -) -> WeightConverter: - if isinstance(fast_llm_name, str): - fast_llm_name = (fast_llm_name,) - if isinstance(hf_name, str): - hf_name = (hf_name,) - if drop_on_export: - cls = IgnoreExportWeightConverter - if drop_on_import: - cls = IgnoreImportWeightConverter - return cls( - () if drop_on_import else fast_llm_name, - () if drop_on_export else hf_name, - config, - ) - - -def get_weight_and_bias_converters( - fast_llm_prefix: str | tuple[str, ...], - hf_prefix: str | tuple[str, ...], - use_bias: bool, - cls=WeightConverter, - config=None, - drop_on_export: bool = False, - drop_on_import: bool = False, -) -> list[WeightConverter]: - if isinstance(fast_llm_prefix, str): - fast_llm_prefix = (fast_llm_prefix,) - if isinstance(hf_prefix, str): - hf_prefix = (hf_prefix,) - converters = [ - get_parameter_converter( - () if drop_on_import else tuple(f"{prefix}.weight" for prefix in fast_llm_prefix), - () if drop_on_export else tuple(f"{prefix}.weight" for prefix in hf_prefix), - cls, - config, - drop_on_export, - drop_on_import, - ) - ] - if use_bias: - converters.append( - get_parameter_converter( - () if drop_on_import else tuple(f"{prefix}.bias" for prefix in fast_llm_prefix), - () if drop_on_export else tuple(f"{prefix}.bias" for prefix in hf_prefix), - cls, - config, - drop_on_export, - drop_on_import, - ) - ) - return converters - - -class MLPLayer2Converter(WeightConverter): - # Similar to SplitWeightConverter, but handles the optional MLP transpose. - # Still ok for non-gated (trivial split) and biases (trivial 1d transpose) - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (merged_weight,) = weight - return tuple(t.contiguous() for t in merged_weight[:].t().chunk(len(self.export_name), dim=-1)) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - merged_weight = torch.cat([weight_[:] for weight_ in weight], dim=-1) - return (merged_weight.t().contiguous(),) - - -class KeyValueWeightConverter(WeightConverter): - # Hf uses the real format for rotary embeddings, and keeps the key and value separate. - _config: AttentionConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - (key_value,) = weight - key, value = key_value[:].chunk(2) - return key, value - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - key, value = weight - key_value = torch.cat([key[:], value[:]]) - return (key_value,) - - # ============================================================ # Config converters (declarative) # ============================================================ @@ -256,19 +158,9 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: RMSNormalizationConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return get_weight_and_bias_converters( - fast_llm_prefix, - () if drop_on_export else hf_prefix, - False, - IgnoreExportWeightConverter if drop_on_export else WeightConverter, - ) + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return {"weight": WeightConverter("weight", "weight")} class LlamaMLPConverter(ConfigSectionConverter): @@ -303,29 +195,12 @@ def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters( - cls, - config: MLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - (f"{hf_prefix}.gate_proj", f"{hf_prefix}.up_proj"), - config.add_linear_biases, - SplitWeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.down_proj", - config.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter("layer_1", ("gate_proj", "up_proj"), transform=SplitWeightConverter), + "layer_2": LinearWeightConverter("layer_2", "down_proj", transform=TransposeSplitWeightConverter), + } class LlamaAttentionConverter(ConfigSectionConverter): @@ -385,35 +260,13 @@ def _validate_export(cls, config: AttentionConfig) -> None: Assert.incl(config.dense_layer.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - config.add_linear_biases, - KeyValueWeightConverter, - config, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - config.add_linear_biases, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "query": LinearWeightConverter("query", "q_proj"), + "key_value": LinearWeightConverter("key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter), + "dense": LinearWeightConverter("dense", "o_proj"), + } class LlamaBlockConverter(ConfigSectionConverter): @@ -448,35 +301,18 @@ def _validate_export(cls, config: DecoderBlockConfig) -> None: Assert.custom(lambda v: not v, config.output_scale.enabled) @classmethod - def get_converters( - cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False - ) -> list[WeightConverter]: - return [ - *cls.mixer_converter_class.get_converters( - config.mixer, - f"{fast_llm_prefix}.mixer", - f"{hf_prefix}.{cls.hf_mixer_name}", - drop_on_export, - ), - *cls.mlp_converter_class.get_converters( - config.mlp, - f"{fast_llm_prefix}.mlp", - f"{hf_prefix}.{cls.hf_mlp_name}", - drop_on_export, - ), - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.{cls.hf_norm_1_name}", - drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": NestedWeightConverter("mixer", cls.hf_mixer_name, cls.mixer_converter_class), + "mlp": NestedWeightConverter("mlp", cls.hf_mlp_name, cls.mlp_converter_class), + "norm_1": NestedWeightConverter( + "norm_1", cls.hf_norm_1_name, cls.normalization_converter_class, config_attr="normalization" ), - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.{cls.hf_norm_2_name}", - drop_on_export, + "norm_2": NestedWeightConverter( + "norm_2", cls.hf_norm_2_name, cls.normalization_converter_class, config_attr="normalization" ), - ] + } def _llama_decoder_export( @@ -521,22 +357,12 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: FixedBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # The section config IS a ``FixedBlockSequenceConfig`` (no parent attribute holding it). + return { + "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), + } class LlamaEmbeddingsConverter(ConfigSectionConverter): @@ -562,10 +388,9 @@ def _validate_export(cls, config: LanguageModelEmbeddingsConfig) -> None: Assert.incl(config.position_embeddings.enabled, (None, False)) @classmethod - def get_converters( - cls, config: LanguageModelEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str - ) -> list[WeightConverter]: - return [WeightConverter(f"{fast_llm_prefix}.word_embeddings_weight", f"{hf_prefix}.embed_tokens.weight")] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return {"word_embeddings": WeightConverter("word_embeddings_weight", "embed_tokens.weight")} class LlamaHeadConverter(ConfigSectionConverter): @@ -574,8 +399,6 @@ class LlamaHeadConverter(ConfigSectionConverter): fast_llm_config_class = LanguageModelHeadConfig normalization_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaNormalizationConverter - # Used by MTP-Llama subclass to emit per-prediction-head block weight converters; Llama itself doesn't read it. - block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = LlamaBlockConverter @classmethod def _create_config_converters(cls) -> dict: @@ -590,28 +413,32 @@ def _create_config_converters(cls) -> dict: "final_logit_softcap": ConstantImportConfigConverter(("final_logit_softcap",), None), } + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``final_norm`` reads the head's own ``normalization`` config; ``output_weights`` is the marker the + # walker drops automatically when the root config has ``tied_embedding_weight=True``. + return { + "final_norm": NestedWeightConverter( + "final_norm", "model.norm", cls.normalization_converter_class, config_attr="normalization" + ), + "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), + } + @classmethod def get_converters( cls, - config: LanguageModelConfig, - exported_config: dict, + config: GPTBaseModelConfig, ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.head.normalization, - f"head.final_norm", - f"model.norm", - ), - get_parameter_converter( - f"head.output_weights", - "lm_head.weight", - drop_on_import=exported_config["tie_word_embeddings"], - drop_on_export=exported_config["tie_word_embeddings"], - ), - ] + """Aggregator entry-point: the base-model converter passes the full :class:`GPTBaseModelConfig` + so subclasses extending the head can read sibling sections (e.g. the decoder) when needed. + Tied-embedding handling lives on :class:`OutputProjectionWeightConverter` and reads + ``root_config.tied_embedding_weight``. + """ + return cls.emit_weight_converters(config.head, "head", "", root_config=config) -class LlamaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class LlamaBaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for ``GPTBaseModelConfig`` ↔ Llama HF dict. Subclasses (Mistral, Qwen2, Mixtral, MTP-Llama, …) override ``block_converter_class`` to plug their @@ -667,22 +494,21 @@ def _validate_export(cls, config: GPTBaseModelConfig) -> None: assert_no_peft(config) @classmethod - def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - decoder_config = config.decoder - block_config = ( - decoder_config.block - if isinstance(decoder_config, FixedBlockSequenceConfig) - else next(iter(decoder_config.blocks.values())) - ) - block_converters: list[WeightConverter] = [] - for block_index in range(decoder_config.num_blocks): - block_converters += cls.block_converter_class.get_converters( - block_config, f"decoder.{block_index}", f"model.layers.{block_index}" - ) + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # ``head`` is added at the aggregator level (in :meth:`get_converters`) because the head + # converter takes the full base-model config so subclasses extending the head can read + # sibling sections. + return { + "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), + "decoder": BlockSequenceWeightConverter("decoder", "model.layers", cls.block_converter_class), + } + + @classmethod + def get_converters(cls, config: GPTBaseModelConfig) -> list[WeightConverter]: return [ - *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), - *block_converters, - *cls.head_converter_class.get_converters(config, exported_config), + *cls.emit_weight_converters(config, "", ""), + *cls.head_converter_class.get_converters(config), ] diff --git a/fast_llm/models/gpt/conversion/mistral.py b/fast_llm/models/gpt/conversion/mistral.py index 18251c760..a9b77fde5 100644 --- a/fast_llm/models/gpt/conversion/mistral.py +++ b/fast_llm/models/gpt/conversion/mistral.py @@ -7,7 +7,6 @@ LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, ) @@ -39,13 +38,8 @@ class MistralBlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[MistralMLPConverter]] = MistralMLPConverter -class MistralHeadConverter(LlamaHeadConverter): - block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter - - class MistralBaseModelConverter(LlamaBaseModelConverter): block_converter_class: typing.ClassVar[type[MistralBlockConverter]] = MistralBlockConverter - head_converter_class: typing.ClassVar[type[MistralHeadConverter]] = MistralHeadConverter class MistralHuggingfaceCheckpointHandler(LlamaHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/mixtral.py b/fast_llm/models/gpt/conversion/mixtral.py index 0403413a9..4d377dcf4 100644 --- a/fast_llm/models/gpt/conversion/mixtral.py +++ b/fast_llm/models/gpt/conversion/mixtral.py @@ -1,20 +1,22 @@ +import functools import typing from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantImportConfigConverter, IgnoredConfigConverter, + LinearWeightConverter, RenameConfigConverter, SplitWeightConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.layers.decoder.mlp.config import MoEMLPConfig, RoutingType from fast_llm.models.gpt.conversion.config import MixtralCheckpointFormat -from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter, MLPLayer2Converter, get_weight_and_bias_converters +from fast_llm.models.gpt.conversion.llama import LlamaMLPConverter from fast_llm.models.gpt.conversion.mistral import ( MistralBaseModelConverter, MistralBlockConverter, - MistralHeadConverter, MistralHuggingfaceCheckpointHandler, ) from fast_llm.utils import Assert @@ -58,35 +60,21 @@ def _validate_export(cls, config: MoEMLPConfig) -> None: Assert.custom(lambda v: not v, config.router_per_expert_scale.enabled) @classmethod - def get_converters( - cls, - config: MoEMLPConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.router", - f"{hf_prefix}.gate", - False, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - tuple(f"{hf_prefix}.experts.{i}.{w}" for i in range(config.experts) for w in ("w1", "w3")), - False, - SplitWeightConverter, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "router": LinearWeightConverter("router", "gate"), + "layer_1": LinearWeightConverter( + "layer_1", + lambda c: tuple(f"experts.{i}.{w}" for i in range(c.experts) for w in ("w1", "w3")), + transform=SplitWeightConverter, ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - tuple(f"{hf_prefix}.experts.{i}.w2" for i in range(config.experts)), - False, - MLPLayer2Converter, - drop_on_export=drop_on_export, + "layer_2": LinearWeightConverter( + "layer_2", + lambda c: tuple(f"experts.{i}.w2" for i in range(c.experts)), + transform=TransposeSplitWeightConverter, ), - ] + } class MixtralBlockConverter(MistralBlockConverter): @@ -94,13 +82,8 @@ class MixtralBlockConverter(MistralBlockConverter): mlp_converter_class: typing.ClassVar[type[MixtralMLPConverter]] = MixtralMLPConverter -class MixtralHeadConverter(MistralHeadConverter): - block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter - - class MixtralBaseModelConverter(MistralBaseModelConverter): block_converter_class: typing.ClassVar[type[MixtralBlockConverter]] = MixtralBlockConverter - head_converter_class: typing.ClassVar[type[MixtralHeadConverter]] = MixtralHeadConverter class MixtralHuggingfaceCheckpointHandler(MistralHuggingfaceCheckpointHandler): diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 6f6d9e88a..1db31c6cd 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -1,22 +1,30 @@ +import functools import typing from transformers import PretrainedConfig from fast_llm.engine.checkpoint.config import CheckpointFormat -from fast_llm.engine.checkpoint.external import RenameConfigConverter, WeightConverter -from fast_llm.layers.language_model.config import LanguageModelConfig -from fast_llm.models.gpt.config import GPTModelConfig +from fast_llm.engine.checkpoint.external import ( + NestedWeightConverter, + OutputProjectionWeightConverter, + RenameConfigConverter, + WeightConverter, +) +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, + LlamaBlockConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, - get_parameter_converter, ) from fast_llm.utils import safe_merge_dicts class MTPLlamaHeadConverter(LlamaHeadConverter): + # The MTP block shape matches the main decoder block, so we plug ``LlamaBlockConverter`` in directly. + block_converter_class: typing.ClassVar[type[LlamaBlockConverter]] = LlamaBlockConverter + @classmethod def _create_config_converters(cls) -> dict: return { @@ -25,37 +33,37 @@ def _create_config_converters(cls) -> dict: "prediction_heads": RenameConfigConverter(("prediction_heads",), ("prediction_heads",)), } + @classmethod + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # MTP-Llama places the first prediction head's final norm under ``model.mtp_norms.0`` instead + # of the standard ``model.norm``; the additional MTP blocks/norms come from the imperative + # ``get_converters`` override below since their count depends on ``head.prediction_heads``. + return { + "final_norm": NestedWeightConverter( + "final_norm", "model.mtp_norms.0", cls.normalization_converter_class, config_attr="normalization" + ), + "output_weights": OutputProjectionWeightConverter("output_weights", "lm_head.weight"), + } + @classmethod def get_converters( cls, - config: LanguageModelConfig, - exported_config: dict, + config: GPTBaseModelConfig, ) -> list[WeightConverter]: - # MTP-Llama uses ``model.mtp_norms.0`` for the first prediction head's final norm - # instead of the standard ``model.norm``. - converters = [ - *cls.normalization_converter_class.get_converters( - config.head.normalization, - "head.final_norm", - "model.mtp_norms.0", - ), - get_parameter_converter( - "head.output_weights", - "lm_head.weight", - drop_on_import=exported_config["tie_word_embeddings"], - drop_on_export=exported_config["tie_word_embeddings"], - ), - ] + converters = list(cls.emit_weight_converters(config.head, "head", "", root_config=config)) for prediction_distance in range(2, config.head.prediction_heads + 1): - converters += cls.block_converter_class.get_converters( + converters += cls.block_converter_class.emit_weight_converters( config.decoder.last_block_config, - f"multi_token_prediction.blocks.{prediction_distance-2}", + f"multi_token_prediction.blocks.{prediction_distance - 2}", f"model.mtp_heads.{prediction_distance - 2}", + root_config=config, ) - converters += cls.normalization_converter_class.get_converters( + converters += cls.normalization_converter_class.emit_weight_converters( config.head.normalization, f"multi_token_prediction.heads.{prediction_distance - 2}.final_norm", f"model.mtp_norms.{prediction_distance - 1}", + root_config=config, ) return converters diff --git a/fast_llm/models/gpt/conversion/qwen2.py b/fast_llm/models/gpt/conversion/qwen2.py index db2b35165..ec2a53c47 100644 --- a/fast_llm/models/gpt/conversion/qwen2.py +++ b/fast_llm/models/gpt/conversion/qwen2.py @@ -1,3 +1,4 @@ +import functools import typing from fast_llm.engine.checkpoint.config import CheckpointFormat @@ -5,6 +6,8 @@ ConstantImportConfigConverter, IgnoredConfigConverter, ImportOnlyConfigConverter, + KeyValueWeightConverter, + LinearWeightConverter, WeightConverter, ) from fast_llm.layers.attention.config import AttentionConfig @@ -12,14 +15,11 @@ from fast_llm.models.gpt.config import GPTBaseModelConfig from fast_llm.models.gpt.conversion.config import Qwen2CheckpointFormat from fast_llm.models.gpt.conversion.llama import ( - KeyValueWeightConverter, LlamaAttentionConverter, LlamaBaseModelConverter, LlamaBlockConverter, - LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, LlamaMLPConverter, - get_weight_and_bias_converters, ) from fast_llm.utils import Assert, div @@ -71,36 +71,18 @@ def _validate_export(cls, config: AttentionConfig) -> None: Assert.is_(config.value_layer.bias.enabled, True) Assert.incl(config.dense_layer.bias.enabled, (None, False)) + # Qwen2 hardcodes Q/K/V biases on, dense bias off — independent of ``add_linear_biases`` (which is + # pinned to False on the config side because there's no HF ``attention_bias`` field). @classmethod - def get_converters( - cls, - config: AttentionConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.query", - f"{hf_prefix}.q_proj", - True, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.key_value", - (f"{hf_prefix}.k_proj", f"{hf_prefix}.v_proj"), - True, - KeyValueWeightConverter, - config, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.dense", - f"{hf_prefix}.o_proj", - False, - drop_on_export=drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "query": LinearWeightConverter("query", "q_proj", bias_fn=True), + "key_value": LinearWeightConverter( + "key_value", ("k_proj", "v_proj"), transform=KeyValueWeightConverter, bias_fn=True ), - ] + "dense": LinearWeightConverter("dense", "o_proj", bias_fn=False), + } class Qwen2MLPConverter(LlamaMLPConverter): @@ -118,10 +100,6 @@ class Qwen2BlockConverter(LlamaBlockConverter): mlp_converter_class: typing.ClassVar[type[Qwen2MLPConverter]] = Qwen2MLPConverter -class Qwen2HeadConverter(LlamaHeadConverter): - block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter - - def _qwen2_mrope_guard_import(hf_dict: dict) -> dict: if hf_dict.get("use_mrope") is True: raise NotImplementedError("MRoPE (use_mrope=True) is not supported by the Qwen2 converter") @@ -130,7 +108,6 @@ def _qwen2_mrope_guard_import(hf_dict: dict) -> dict: class Qwen2BaseModelConverter(LlamaBaseModelConverter): block_converter_class: typing.ClassVar[type[Qwen2BlockConverter]] = Qwen2BlockConverter - head_converter_class: typing.ClassVar[type[Qwen2HeadConverter]] = Qwen2HeadConverter @classmethod def _create_config_converters(cls) -> dict: diff --git a/fast_llm/models/multimodal/conversion/apriel2.py b/fast_llm/models/multimodal/conversion/apriel2.py index 6fe509a9d..d7832db19 100644 --- a/fast_llm/models/multimodal/conversion/apriel2.py +++ b/fast_llm/models/multimodal/conversion/apriel2.py @@ -1,17 +1,24 @@ """Apriel2 multimodal checkpoint format converter.""" +import functools import typing from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantExportConfigConverter, ConstantImportConfigConverter, CustomConfigConverter, IgnoredConfigConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, OptionalConfigConverter, + PatchEmbeddingWeightConverter, RenameConfigConverter, + SelfBlockSequenceWeightConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -22,27 +29,18 @@ from fast_llm.layers.common.normalization.config import RMSNormalizationConfig from fast_llm.layers.decoder.config import DecoderBlockConfig from fast_llm.layers.decoder.mlp.config import MLPConfig -from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.apriel2 import ( Apriel2BaseModelConverter, + Apriel2BlockConverter, Apriel2HeadConverter, Apriel2MLPConverter, Apriel2RMSNormConverter, - get_apriel2_decoder_converter, -) -from fast_llm.models.gpt.conversion.llama import ( - LlamaEmbeddingsConverter, - LlamaNormalizationConverter, - get_parameter_converter, - get_weight_and_bias_converters, ) +from fast_llm.models.gpt.conversion.llama import LlamaEmbeddingsConverter, LlamaNormalizationConverter from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import Apriel2CheckpointFormat -from fast_llm.models.multimodal.conversion.llava import ( - PatchEmbeddingWeightConverter, - PixtralAttentionConverter, -) +from fast_llm.models.multimodal.conversion.llava import PixtralAttentionConverter from fast_llm.models.multimodal.model import MultiModalModel from fast_llm.utils import Assert @@ -198,29 +196,18 @@ def _validate_export(cls, config: DecoderBlockConfig) -> None: Assert.custom(lambda v: not v, config.output_scale.enabled) @classmethod - def get_converters( - cls, config: DecoderBlockConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False - ) -> list[WeightConverter]: - return [ - *cls.mixer_converter_class.get_converters( - config.mixer, f"{fast_llm_prefix}.mixer", f"{hf_prefix}.{cls.hf_mixer_name}", drop_on_export - ), - *cls.mlp_converter_class.get_converters( - config.mlp, f"{fast_llm_prefix}.mlp", f"{hf_prefix}.{cls.hf_mlp_name}", drop_on_export - ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_1", - f"{hf_prefix}.{cls.hf_norm_1_name}", - drop_on_export, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "mixer": NestedWeightConverter("mixer", cls.hf_mixer_name, cls.mixer_converter_class), + "mlp": NestedWeightConverter("mlp", cls.hf_mlp_name, cls.mlp_converter_class), + "norm_1": NestedWeightConverter( + "norm_1", cls.hf_norm_1_name, LlamaNormalizationConverter, config_attr="normalization" ), - *LlamaNormalizationConverter.get_converters( - config.normalization, - f"{fast_llm_prefix}.norm_2", - f"{hf_prefix}.{cls.hf_norm_2_name}", - drop_on_export, + "norm_2": NestedWeightConverter( + "norm_2", cls.hf_norm_2_name, LlamaNormalizationConverter, config_attr="normalization" ), - ] + } class Apriel2VisionEncoderConverter(ConfigSectionConverter): @@ -264,22 +251,11 @@ def _create_config_converters(cls) -> dict: } @classmethod - def get_converters( - cls, - config: FixedBlockSequenceConfig, - fast_llm_prefix: str, - hf_prefix: str, - drop_on_export: bool = False, - ) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - for block_index in range(config.num_blocks): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.{block_index}", - f"{hf_prefix}.{block_index}", - drop_on_export, - ) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "blocks": SelfBlockSequenceWeightConverter(cls.block_converter_class), + } class Apriel2EmbeddingsConverter(ConfigSectionConverter): @@ -322,21 +298,19 @@ def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) @classmethod - def get_converters( - cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.patch_embeddings", - f"{hf_prefix}.patch_embeddings", - False, - PatchEmbeddingWeightConverter, - config, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "patch_embeddings": LinearWeightConverter( + "patch_embeddings", + "patch_embeddings", + transform=PatchEmbeddingWeightConverter, + bias_fn=False, ), - *LlamaNormalizationConverter.get_converters( - config.normalization, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.normalization" + "normalization": NestedWeightConverter( + "normalization", "normalization", LlamaNormalizationConverter, config_attr="normalization" ), - ] + } class Apriel2VisionAdapterConverter(Apriel2VisionMLPConverter): @@ -353,27 +327,12 @@ def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters( - cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str, drop_on_export: bool = False - ) -> list[WeightConverter]: - from fast_llm.models.gpt.conversion.llama import MLPLayer2Converter - - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - f"{hf_prefix}.linear_1", - config.add_linear_biases, - WeightConverter, - drop_on_export=drop_on_export, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.linear_2", - config.add_linear_biases, - MLPLayer2Converter, - drop_on_export=drop_on_export, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter("layer_1", "linear_1"), + "layer_2": LinearWeightConverter("layer_2", "linear_2", transform=TransposeSplitWeightConverter), + } class Apriel2VisionModelConverter(ConfigSectionConverter): @@ -427,44 +386,23 @@ def _inject_rotary_metadata(config: VisionEncoderConfig) -> dict: return {} @classmethod - def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: - return [ - *cls.embeddings_converter_class.get_converters( - config.embeddings, "vision_encoder.embeddings", cls.hf_embeddings_prefix - ), - *cls.encoder_converter_class.get_converters( - config.encoder, "vision_encoder.encoder", cls.hf_encoder_prefix - ), - *cls.vision_adapter_converter_class.get_converters( - config.adapter, "vision_encoder.adapter", cls.hf_adapter_prefix - ), - ] - - -class Apriel2MultimodalHeadConverter(Apriel2HeadConverter): - @classmethod - def get_converters( - cls, - config: LanguageModelHeadConfig, - exported_config: dict, - fast_llm_prefix: str, - ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.normalization, - f"{fast_llm_prefix}.final_norm", - "model.norm", - ), - get_parameter_converter( - f"{fast_llm_prefix}.output_weights", - "lm_head.weight", - drop_on_import=exported_config.get("tie_word_embeddings", False), - drop_on_export=exported_config.get("tie_word_embeddings", False), + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + # The vision converter normally lives under a single ``vision_encoder`` HF prefix, but Apriel2's + # state-dict places each piece at distinct absolute paths (``model.vision_encoder.{embeddings, + # encoder.blocks, adapter}``). Use ``NestedWeightConverter`` with the absolute prefix on each entry; + # the parent's :class:`NestedWeightConverter("vision_encoder", "", ...)` passes ``hf_prefix=""``, + # so the absolute prefix lands as-is. + return { + "embeddings": NestedWeightConverter( + "embeddings", cls.hf_embeddings_prefix, cls.embeddings_converter_class ), - ] + "encoder": NestedWeightConverter("encoder", cls.hf_encoder_prefix, cls.encoder_converter_class), + "adapter": NestedWeightConverter("adapter", cls.hf_adapter_prefix, cls.vision_adapter_converter_class), + } -class Apriel2MultimodalBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class Apriel2MultimodalBaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for Apriel2 multimodal. Composes the Apriel2 text base (flat-merged into the HF top-level dict) with an optional vision encoder (under HF key ``vision_encoder``) and an optional ``image_token_index`` field. @@ -479,7 +417,8 @@ class Apriel2MultimodalBaseModelConverter(ConfigSectionConverter, HuggingFaceBas text_base_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BaseModelConverter vision_model_converter_class: typing.ClassVar[type[Apriel2VisionModelConverter]] = Apriel2VisionModelConverter embeddings_converter_class: typing.ClassVar[type[LlamaEmbeddingsConverter]] = LlamaEmbeddingsConverter - head_converter_class: typing.ClassVar[type[Apriel2MultimodalHeadConverter]] = Apriel2MultimodalHeadConverter + block_converter_class: typing.ClassVar[type[ConfigSectionConverter]] = Apriel2BlockConverter + head_converter_class: typing.ClassVar[type[Apriel2HeadConverter]] = Apriel2HeadConverter @classmethod def _create_config_converters(cls) -> dict: @@ -529,18 +468,19 @@ def _vision_import(hf_dict: dict) -> dict: } @classmethod - def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: - converters: list[WeightConverter] = [] - if config.vision_encoder is not None: - converters.extend(cls.vision_model_converter_class.get_converters(config.vision_encoder)) - converters.extend(cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model")) - converters.extend( - get_apriel2_decoder_converter(config.decoder).get_converters( - config.decoder, "decoder", "model.decoder.blocks" - ) - ) - converters.extend(cls.head_converter_class.get_converters(config.head, exported_config, "head")) - return converters + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + # Vision encoder is optional — ``optional=True`` skips the recursion when + # ``config.vision_encoder is None`` (text-only checkpoints). + "vision_encoder": NestedWeightConverter( + "vision_encoder", "", cls.vision_model_converter_class, optional=True + ), + # ``embeddings`` flat-merges into ``model``; vision_encoder writes to its own absolute prefix. + "embeddings": NestedWeightConverter("embeddings", "model", cls.embeddings_converter_class), + "decoder": BlockSequenceWeightConverter("decoder", "model.decoder.blocks", cls.block_converter_class), + "head": NestedWeightConverter("head", "", cls.head_converter_class), + } class Apriel2HuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @@ -586,7 +526,3 @@ def _export_config(cls, config: MultiModalModelConfig) -> dict[str, typing.Any]: def _import_config(cls, config: dict[str, typing.Any]) -> dict[str, typing.Any]: cls._check_hf_coverage(config) return {"base_model": cls.base_model_converter_class.import_config(config)} - - @classmethod - def _get_weight_converters(cls, config: MultiModalModelConfig, export_config: dict) -> list[WeightConverter]: - return cls.base_model_converter_class.get_converters(config.base_model, export_config) diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index fe18ef3cb..1a67cf8ec 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -1,17 +1,21 @@ +import functools import typing -import torch - from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( + BlockSequenceWeightConverter, ConfigSectionConverter, ConstantExportConfigConverter, ConstantImportConfigConverter, CustomConfigConverter, IgnoredConfigConverter, ImportOnlyConfigConverter, + LinearWeightConverter, NestedConfigConverter, + NestedWeightConverter, + PatchEmbeddingWeightConverter, RenameConfigConverter, + TransposeSplitWeightConverter, WeightConverter, ) from fast_llm.engine.checkpoint.huggingface import HuggingFaceBaseModelConverter, HuggingfaceStateDictCheckpointHandler @@ -19,25 +23,20 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import AttentionConfig from fast_llm.layers.attention.rotary.config import Rotary2DConfig -from fast_llm.layers.block.config import FixedBlockSequenceConfig from fast_llm.layers.decoder.mlp.config import MLPConfig -from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionEncoderConfig from fast_llm.models.gpt.conversion.llama import ( _TRANSFORMERS_V4, LlamaAttentionConverter, LlamaBlockConverter, LlamaDecoderConverter, + LlamaHeadConverter, LlamaNormalizationConverter, - MLPLayer2Converter, - get_parameter_converter, - get_weight_and_bias_converters, ) -from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter, MistralHeadConverter, MistralMLPConverter +from fast_llm.models.gpt.conversion.mistral import MistralBaseModelConverter, MistralMLPConverter from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalModelConfig from fast_llm.models.multimodal.conversion.config import LlavaCheckpointFormat from fast_llm.models.multimodal.model import MultiModalModel -from fast_llm.tensor import SafeTensorSlice from fast_llm.utils import Assert, div @@ -123,34 +122,6 @@ class PixtralEncoderConverter(LlamaDecoderConverter): block_converter_class: typing.ClassVar[type[PixtralBlockConverter]] = PixtralBlockConverter -class PatchEmbeddingWeightConverter(WeightConverter): - _config: PatchEmbeddingsConfig - - def export_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return tuple( - weight_[:].view( - *weight_[:].shape[:-1], - self._config.input_channels, - self._config.patch_height, - self._config.patch_width, - ) - for weight_ in weight - ) - - def import_weight( - self, weight: tuple[torch.Tensor | SafeTensorSlice, ...] - ) -> tuple[torch.Tensor | SafeTensorSlice, ...]: - return tuple( - weight_[:].view( - *weight_[:].shape[:-3], - self._config.input_channels * self._config.patch_height * self._config.patch_width, - ) - for weight_ in weight - ) - - class PixtralEmbeddingsConverter(ConfigSectionConverter): """Converts ``PatchEmbeddingsConfig`` ↔ Pixtral HF flat fields (``patch_size`` / ``num_channels``). @@ -184,21 +155,22 @@ def _validate_export(cls, config: PatchEmbeddingsConfig) -> None: Assert.incl(config.patch_embeddings.bias.enabled, (None, False)) @classmethod - def get_converters( - cls, config: PatchEmbeddingsConfig, fast_llm_prefix: str, hf_prefix: str - ) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.patch_embeddings", - f"{hf_prefix}.patch_conv", - False, - PatchEmbeddingWeightConverter, - config, + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "patch_embeddings": LinearWeightConverter( + "patch_embeddings", + "patch_conv", + transform=PatchEmbeddingWeightConverter, + bias_fn=False, ), - *cls.normalization_converter_class.get_converters( - config, f"{fast_llm_prefix}.normalization", f"{hf_prefix}.ln_pre" + # ``PixtralEmbeddingsConverter``'s section config IS the ``PatchEmbeddingsConfig`` (carries the + # ``normalization`` sub-config directly), so the nested ``LlamaNormalizationConverter`` reads + # ``getattr(section_config, "normalization")``. + "normalization": NestedWeightConverter( + "normalization", "ln_pre", cls.normalization_converter_class, config_attr="normalization" ), - ] + } class LlavaVisionAdapterConverter(ConfigSectionConverter): @@ -243,21 +215,12 @@ def _validate_export(cls, config: MLPConfig) -> None: Assert.incl(config.layer_2.bias.enabled, (None, config.add_linear_biases)) @classmethod - def get_converters(cls, config: MLPConfig, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]: - return [ - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_1", - f"{hf_prefix}.linear_1", - config.add_linear_biases, - WeightConverter, - ), - *get_weight_and_bias_converters( - f"{fast_llm_prefix}.layer_2", - f"{hf_prefix}.linear_2", - config.add_linear_biases, - MLPLayer2Converter, - ), - ] + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "layer_1": LinearWeightConverter("layer_1", "linear_1"), + "layer_2": LinearWeightConverter("layer_2", "linear_2", transform=TransposeSplitWeightConverter), + } class LlavaVisionModelConverter(ConfigSectionConverter): @@ -315,46 +278,41 @@ def _validate_export(cls, config: VisionEncoderConfig) -> None: Assert.eq(mixer.head_size * mixer.heads, config.hidden_size) @classmethod - def get_converters(cls, config: VisionEncoderConfig) -> list[WeightConverter]: - return [ - *cls.embeddings_converter_class.get_converters( - config.embeddings, "vision_encoder.embeddings", "vision_tower" - ), - *cls.encoder_converter_class.get_converters( - config.encoder, "vision_encoder.encoder", "vision_tower.transformer.layers" - ), - *LlavaVisionAdapterConverter.get_converters( - config.adapter, "vision_encoder.adapter", "multi_modal_projector" + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "embeddings": NestedWeightConverter("embeddings", "vision_tower", cls.embeddings_converter_class), + # The encoder section IS a FixedBlockSequenceConfig — fan out blocks via the dedicated primitive. + "encoder": NestedWeightConverter( + "encoder", "vision_tower.transformer.layers", cls.encoder_converter_class ), - ] + "adapter": NestedWeightConverter("adapter", "multi_modal_projector", LlavaVisionAdapterConverter), + } -class LlavaHeadConverter(MistralHeadConverter): +class LlavaHeadConverter(LlamaHeadConverter): + # Llava always emits a separate ``language_model.lm_head.weight`` declaration even when + # ``tied_embedding_weight=True``, so the head uses a plain rename instead of + # :class:`OutputProjectionWeightConverter` (which drops on export under the tied flag). @classmethod - def get_converters( - cls, - config: LanguageModelConfig, - exported_config: dict, - ) -> list[WeightConverter]: - return [ - *cls.normalization_converter_class.get_converters( - config.head.normalization, - f"head.final_norm", - f"language_model.model.norm", - ), - get_parameter_converter( - f"head.output_weights", - "language_model.lm_head.weight", - drop_on_import=exported_config["tie_word_embeddings"], + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: + return { + "final_norm": NestedWeightConverter( + "final_norm", + "language_model.model.norm", + cls.normalization_converter_class, + config_attr="normalization", ), - ] + "output_weights": WeightConverter("output_weights", "language_model.lm_head.weight"), + } class LlavaLanguageModelConverter(MistralBaseModelConverter): head_converter_class: typing.ClassVar[type[LlavaHeadConverter]] = LlavaHeadConverter -class LlavaBaseModelConverter(ConfigSectionConverter, HuggingFaceBaseModelConverter): +class LlavaBaseModelConverter(HuggingFaceBaseModelConverter): """Top-level converter for Llava. Composes: * ``text_config`` HF subdict ← :class:`LlavaLanguageModelConverter` (Mistral text base). @@ -428,27 +386,21 @@ def _validate_export(cls, config: MultiModalBaseModelConfig) -> None: assert config.image_token_index is not None, "Llava requires an image_token_index" @classmethod - def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict) -> list[WeightConverter]: + @functools.cache + def _create_weight_converters(cls) -> dict[str, WeightConverter]: text_base_cls = cls.language_model_converter_class - decoder_config = config.decoder - block_config = ( - decoder_config.block - if isinstance(decoder_config, FixedBlockSequenceConfig) - else next(iter(decoder_config.blocks.values())) - ) - block_converters: list[WeightConverter] = [] - for block_index in range(decoder_config.num_blocks): - block_converters += text_base_cls.block_converter_class.get_converters( - block_config, f"decoder.{block_index}", f"language_model.model.layers.{block_index}" - ) - return [ - *cls.vision_model_converter_class.get_converters(config.vision_encoder), - *text_base_cls.embeddings_converter_class.get_converters( - config.embeddings, "embeddings", "language_model.model" + return { + "vision_encoder": NestedWeightConverter("vision_encoder", "", cls.vision_model_converter_class), + "embeddings": NestedWeightConverter( + "embeddings", "language_model.model", text_base_cls.embeddings_converter_class ), - *block_converters, - *text_base_cls.head_converter_class.get_converters(config, {"tie_word_embeddings": False}), - ] + "decoder": BlockSequenceWeightConverter( + "decoder", "language_model.model.layers", text_base_cls.block_converter_class + ), + # ``LlavaHeadConverter``'s leaf converters use absolute HF paths (``language_model.lm_head.weight``, + # ``language_model.model.norm``), so an empty ``hf_prefix`` lets them land verbatim. + "head": NestedWeightConverter("head", "", text_base_cls.head_converter_class), + } class LlavaHuggingfaceCheckpointHandler(HuggingfaceStateDictCheckpointHandler):