28 commits
6720c18
Reclassify architecture-impacting fields under FieldHint.architecture
jlamypoirier May 5, 2026
0d393b2
Add declarative ConfigConverter primitives and section-converter ABC
jlamypoirier May 5, 2026
0c406db
Migrate Llama config converters to declarative primitives
jlamypoirier May 5, 2026
0a061c4
Address PR review: validator hook + cleanup
jlamypoirier May 6, 2026
79f8364
Mark PatternBlockSequenceConfig.blocks as architecture
jlamypoirier May 6, 2026
0484dc5
Extend converter framework with nested-HF and typed-dict primitives
jlamypoirier May 6, 2026
fcbd282
Migrate Apriel2 config converters to declarative primitives
jlamypoirier May 6, 2026
f510383
Migrate Mistral / Qwen2 / MTP-Llama config converters to declarative …
jlamypoirier May 6, 2026
df438c4
Migrate Mixtral config converters to declarative primitives
jlamypoirier May 6, 2026
1b025db
Migrate Apriel hybrid SSM mixer config converters to declarative prim…
jlamypoirier May 6, 2026
17b91d9
Migrate Pixtral normalization and embeddings config converters to dec…
jlamypoirier May 6, 2026
dc418d1
Remove unused weight-converter scaffolding
jlamypoirier May 6, 2026
8272abf
Self-review fixes
jlamypoirier May 7, 2026
8314f12
Address review feedback
jlamypoirier May 7, 2026
eb9f179
Address second review round
jlamypoirier May 8, 2026
0588262
Address third review round
jlamypoirier May 8, 2026
7d013e1
Add HF-side coverage check on import
jlamypoirier May 11, 2026
536d548
Address fourth review round
jlamypoirier May 11, 2026
d34aac4
Address fifth review round
jlamypoirier May 11, 2026
5858160
Address sixth review round
jlamypoirier May 12, 2026
807fbe1
Migrate Apriel2 multimodal config converters to declarative
jlamypoirier May 14, 2026
6cfcc2c
Migrate Llava multimodal config converters to declarative
jlamypoirier May 14, 2026
d35e39c
Flatten LlamaDecoderConverter chain + Qwen2 MRoPE declarative
jlamypoirier May 14, 2026
b3b41b7
Claim transformers metadata keys in Llava vision_config
jlamypoirier May 14, 2026
70c63ba
Apply HF metadata allowlist recursively in coverage check
jlamypoirier May 14, 2026
1b2cf9d
Claim transformers' Pixtral and Llava HF defaults
jlamypoirier May 14, 2026
f023f9e
Merge main: cover new architecture fields in declarative converters
jlamypoirier May 14, 2026
5b31071
Port Gemma4BaseModelConverter to ConfigSectionConverter
jlamypoirier May 14, 2026
2 changes: 1 addition & 1 deletion CLAUDE.md
@@ -174,7 +174,7 @@ Tests live in `tests/`. The following patterns work well in this codebase.

 - **Comments**: Write no comments by default. Only add one when the *why* is non-obvious — a hidden constraint, a subtle invariant, a workaround for a specific bug, behavior that would surprise a reader. Never restate what the code already says; well-named identifiers do that.
 - **Imports**: Third-party → `import package.module` (keep fully qualified). First-party → `from fast_llm.module import Thing`. No relative imports. Optional/slow imports inside methods or under `if typing.TYPE_CHECKING:`.
-- **Naming**: No abbreviations (use `batch_size` not `bs`). Private members get a single `_` prefix; never use `__`. Keep public interfaces lean.
+- **Naming**: No abbreviations (use `batch_size` not `bs`). Non-public members (private or protected) get a single `_` prefix; never use `__`. Keep public interfaces lean.
 - **Types**: Always type-hint public interfaces. Use modern syntax (`X | Y`, `list[T]` not `List[T]`, PEP 695 generics like `class X[T: Bound]:` instead of `typing.TypeVar`).
 - **Assert**: Use the `Assert` namespace from `fast_llm.utils` for contract checks (`Assert.eq`, `Assert.geq`, `Assert.incl`, `Assert.custom`, etc.) — error messages auto-format with actual values. Bare `assert` is reserved for internal state-validity invariants (`assert self._is_setup`).
 - **Exceptions**: Raise stdlib exceptions for runtime errors (`ValueError`, `RuntimeError`, `NotImplementedError`). Custom exception classes are rare — only `ValidationError`, `NestedValidationError`, `FieldTypeError` in `config.py`.
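For illustration, a short hypothetical snippet consistent with the conventions above (class and member names are invented, not part of the PR):

```python
import typing

from fast_llm.utils import Assert  # first-party: fully qualified from-import

if typing.TYPE_CHECKING:
    import torch  # optional/slow import kept behind TYPE_CHECKING


class TokenBuffer[T: int]:
    """Hypothetical class illustrating the naming, typing and Assert conventions."""

    def __init__(self, batch_size: int, tokens: list[T] | None = None) -> None:
        Assert.geq(batch_size, 1)  # contract check through the Assert namespace
        self._batch_size = batch_size  # non-public member: single `_` prefix
        self._tokens: list[T] = [] if tokens is None else tokens
```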
674 changes: 661 additions & 13 deletions fast_llm/engine/checkpoint/external.py

Large diffs are not rendered by default.

55 changes: 54 additions & 1 deletion fast_llm/engine/checkpoint/huggingface.py
@@ -9,7 +9,12 @@

 from fast_llm.engine.base_model.config import BaseModelConfig
 from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig
-from fast_llm.engine.checkpoint.external import ExternalStateDictCheckpointHandler, WeightConverter, logger
+from fast_llm.engine.checkpoint.external import (
+    ConfigSectionConverter,
+    ExternalStateDictCheckpointHandler,
+    WeightConverter,
+    logger,
+)
 from fast_llm.engine.multi_stage.config import CheckpointMetadata, FastLLMModelConfig
 from fast_llm.tensor import SafeTensorSlice
 from fast_llm.utils import Assert, safe_merge_dicts
@@ -120,10 +125,58 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
             },
         )
 
+    # Top-level HF metadata keys that are always permitted, regardless of the converter tree.
+    # Covers transformers' generic ``PretrainedConfig`` fields (always present after ``to_dict()``)
+    # plus a handful of widely-shared metadata that Fast-LLM intentionally does not store.
+    _HF_METADATA_ALLOWLIST: typing.ClassVar[frozenset[str]] = frozenset(
+        {
+            # transformers PretrainedConfig
+            "_name_or_path",
+            "architectures",
+            "auto_map",
+            "chunk_size_feed_forward",
+            "dtype",
+            "id2label",
+            "is_encoder_decoder",
+            "label2id",
+            "model_type",
+            "output_attentions",
+            "output_hidden_states",
+            "problem_type",
+            "return_dict",
+            "torch_dtype",
+            "transformers_version",
+            "use_cache",
+            # Token ids — generation/inference, not architecture.
+            "bos_token_id",
+            "decoder_start_token_id",
+            "eos_token_id",
+            "pad_token_id",
+            "sep_token_id",
+            # Initialization / pretraining metadata Fast-LLM does not consume.
+            "initializer_range",
+            "max_position_embeddings",
+            "pretraining_tp",
+        }
+    )
+
+    @classmethod
+    def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None:
+        """Run the HF-side coverage check at the import boundary.
+
+        Skips silently when the format's base-model converter isn't a :class:`ConfigSectionConverter`
+        (e.g. multimodal aggregators built on top of imperative ``HuggingFaceBaseModelConverter``).
+        Subclasses that override :meth:`_import_config` should call this explicitly to keep the check
+        active.
+        """
+        if issubclass(cls.base_model_converter_class, ConfigSectionConverter):
+            cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST)
+
     @classmethod
     def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
         Assert.eq(config["model_type"], cls.get_huggingface_model_type())
         Assert.eq(config["architectures"], [cls.architecture])
+        cls._check_hf_coverage(config)
         return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})
 
     def _create_weight_converters(self) -> list[WeightConverter]:
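To make the docstring's expectation concrete, here is a hypothetical subclass that overrides `_import_config` and keeps the coverage check active (the base-class name and import path are assumptions, not code from this PR):

```python
import typing

from fast_llm.engine.checkpoint.huggingface import HuggingfaceStateDictCheckpointHandler  # name assumed
from fast_llm.engine.multi_stage.config import FastLLMModelConfig


class MyModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler):  # hypothetical subclass
    @classmethod
    def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
        cls._check_hf_coverage(config)  # keep the HF-side coverage check active despite the override
        base_model = cls.base_model_converter_class.import_config(config)
        return cls._model_class.from_dict({"base_model": base_model})
```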
3 changes: 2 additions & 1 deletion fast_llm/layers/attention/config.py
@@ -63,7 +63,7 @@ class AttentionConfig(MixerConfig):
     )
     dense_layer: AffineLinearConfig = Field(
         desc="Initialization configuration for the dense layer.",
-        hint=FieldHint.feature,
+        hint=FieldHint.architecture,
     )
     # TODO: Review names
     rotary: RotaryConfig = Field(
@@ -116,6 +116,7 @@ class AttentionConfig(MixerConfig):
         " Under Standard Parameterization (SP): default to 0.5. "
         " Under muP (if scaling head_size size): use 1. "
         " Under muP (if scaling number of heads instead of head_size): use 0.5.",
+        hint=FieldHint.architecture,
         valid=skip_valid_if_none(check_field(Assert.geq, 0)),
     )
     implementation: AttentionImplementation = Field(
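As a reading aid for the powers quoted in the field description, this is the usual convention they refer to; the formula is an assumption about the convention, not Fast-LLM's implementation:

```python
# Assumed convention: attention logits are scaled by head_size ** -softmax_scale_power.
def softmax_scale(head_size: int, softmax_scale_power: float) -> float:
    # 0.5 -> 1 / sqrt(head_size)  (standard parameterization)
    # 1.0 -> 1 / head_size        (muP when scaling head_size)
    return head_size**-softmax_scale_power
```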
18 changes: 9 additions & 9 deletions fast_llm/layers/attention/rotary/config.py
@@ -79,10 +79,10 @@ class Llama3RotaryConfig(DefaultRotaryConfig):
     """
 
     # TODO: Add descriptions.
-    scale_factor: float = Field(default=8.0, hint=FieldHint.feature)
-    low_frequency_factor: float = Field(default=1.0, hint=FieldHint.feature)
-    high_frequency_factor: float = Field(default=4.0, hint=FieldHint.feature)
-    original_context_length: int = Field(default=8192, hint=FieldHint.feature)
+    scale_factor: float = Field(default=8.0, hint=FieldHint.architecture)
+    low_frequency_factor: float = Field(default=1.0, hint=FieldHint.architecture)
+    high_frequency_factor: float = Field(default=4.0, hint=FieldHint.architecture)
+    original_context_length: int = Field(default=8192, hint=FieldHint.architecture)
 
     def _validate(self) -> None:
         super()._validate()
@@ -103,20 +103,20 @@ class YarnRotaryConfig(DefaultRotaryConfig):
     """
 
     # TODO: Add descriptions.
-    scale_factor: float = Field(default=8.0, hint=FieldHint.feature)
+    scale_factor: float = Field(default=8.0, hint=FieldHint.architecture)
     attention_factor: None | float = Field(
         default=None,
-        hint=FieldHint.feature,
+        hint=FieldHint.architecture,
     )
     beta_fast: float = Field(
         default=32.0,
-        hint=FieldHint.feature,
+        hint=FieldHint.architecture,
     )
     beta_slow: float = Field(
         default=1.0,
-        hint=FieldHint.feature,
+        hint=FieldHint.architecture,
    )
-    original_context_length: int = Field(default=8192, hint=FieldHint.feature)
+    original_context_length: int = Field(default=8192, hint=FieldHint.architecture)
 
     def _validate(self) -> None:
         if self.attention_factor is None:
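For orientation, the four `Llama3RotaryConfig` fields reclassified above conventionally parameterize the Llama-3-style RoPE frequency rescaling sketched below; this follows the publicly known recipe and is not necessarily Fast-LLM's exact code:

```python
import math


def llama3_scaled_frequencies(
    freqs: list[float],
    scale_factor: float = 8.0,
    low_frequency_factor: float = 1.0,
    high_frequency_factor: float = 4.0,
    original_context_length: int = 8192,
) -> list[float]:
    # Wavelengths shorter than original_context_length / high_frequency_factor stay untouched,
    # wavelengths longer than original_context_length / low_frequency_factor are divided by
    # scale_factor, and the band in between is interpolated smoothly.
    low_wavelength = original_context_length / low_frequency_factor
    high_wavelength = original_context_length / high_frequency_factor
    scaled = []
    for freq in freqs:
        wavelength = 2 * math.pi / freq
        if wavelength < high_wavelength:
            scaled.append(freq)
        elif wavelength > low_wavelength:
            scaled.append(freq / scale_factor)
        else:
            smooth = (original_context_length / wavelength - low_frequency_factor) / (
                high_frequency_factor - low_frequency_factor
            )
            scaled.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return scaled
```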
5 changes: 4 additions & 1 deletion fast_llm/layers/block/config.py
@@ -146,7 +146,10 @@ def last_block_config(self) -> BlockConfig:
 @config_class(dynamic_type={BlockSequenceConfig: "pattern"})
 class PatternBlockSequenceConfig(BlockSequenceConfig):
     _abstract = False
-    blocks: dict[str, BlockConfig] = Field()
+    blocks: dict[str, BlockConfig] = Field(
+        desc="Named block configurations referenced by `pattern`.",
+        hint=FieldHint.architecture,
+    )
     pattern: list[str] = Field(
         default=None,
         desc="The name of each block (key in `blocks`) in the repeated pattern.",
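To make the relationship between the two fields concrete, a hypothetical fragment of the kind of config they describe (block names invented; shown as a plain dict rather than any specific Fast-LLM config syntax):

```python
# Hypothetical: two named block configurations repeated in an A-A-B layout.
# Every entry of `pattern` must be a key of `blocks`.
pattern_block_sequence = {
    "type": "pattern",
    "blocks": {
        "attention_block": {},  # a full BlockConfig would go here
        "ssm_block": {},
    },
    "pattern": ["attention_block", "attention_block", "ssm_block"],
}
```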
2 changes: 1 addition & 1 deletion fast_llm/layers/decoder/config.py
@@ -167,7 +167,7 @@ class StochasticMixerConfig(MixerConfig):
         "Used for inference/eval, checkpoint loading (receives pretrained weights), "
         "and checkpoint saving (only this mixer is exported). "
         "If None, uses the first mixer in the dict.",
-        hint=FieldHint.feature,
+        hint=FieldHint.architecture,
     )
 
     predefined_layouts: list[list[str]] = Field(
4 changes: 2 additions & 2 deletions fast_llm/layers/decoder/mlp/config.py
@@ -64,7 +64,7 @@ class MLPConfig(MLPBaseConfig):
     activation: ActivationType = Field(
         default=None,
         desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
-        hint=FieldHint.core,
+        hint=FieldHint.architecture,
     )
     # normalization_implementation: NormalizationImplementation = NormalizationImplementation.auto
     recompute_level: MLPRecomputeLevel = Field(
@@ -97,7 +97,7 @@ class MoEMLPConfig(MLPConfig):
     router: LinearConfig = Field(
         # TODO: Improve default?
         desc="Configuration for the MoE router.",
-        hint=FieldHint.feature,
+        hint=FieldHint.architecture,
     )
     router_normalization: NormalizationConfig | None = Field(
         default=None,
4 changes: 2 additions & 2 deletions fast_llm/layers/vision/config.py
@@ -34,12 +34,12 @@ class PatchEmbeddingsConfig(BlockConfig):
     patch_height: int = Field(
         default=16,
         desc="Height of image patches, in pixels.",
-        hint=FieldHint.core,
+        hint=FieldHint.architecture,
     )
     patch_width: int = Field(
         default=16,
         desc="Width of image patches, in pixels.",
-        hint=FieldHint.core,
+        hint=FieldHint.architecture,
     )
     full_precision_residual: bool = Field(
         default=False,
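For reference, the quantity these two patch-size fields typically determine; this is the generic vision-transformer relation, not code from this PR:

```python
def num_patches(image_height: int, image_width: int, patch_height: int = 16, patch_width: int = 16) -> int:
    # A patch-embedding layer splits the image into an (H / patch_height) x (W / patch_width) grid,
    # so changing either field changes the sequence length the vision encoder produces.
    return (image_height // patch_height) * (image_width // patch_width)


assert num_patches(1024, 1024) == 64 * 64
```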