Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
551 changes: 525 additions & 26 deletions fast_llm/engine/checkpoint/external.py

Large diffs are not rendered by default.

35 changes: 14 additions & 21 deletions fast_llm/engine/checkpoint/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,27 @@
import transformers


class HuggingFaceBaseModelConverter:
@classmethod
@abc.abstractmethod
def import_config(cls, config: dict) -> dict:
pass
class HuggingFaceBaseModelConverter(ConfigSectionConverter):
"""Section converter for a full HF model root. Inherits the declarative config-side machinery from
:class:`ConfigSectionConverter` (``import_config`` / ``export_config`` driven by
``_create_config_converters``) and adds the weight-side ``get_converters`` aggregation that the
enclosing :class:`HuggingfaceStateDictCheckpointHandler` invokes.
"""

@classmethod
@abc.abstractmethod
def export_config(cls, config: BaseModelConfig) -> dict:
pass
def get_converters(cls, config: BaseModelConfig) -> list[WeightConverter]:
"""Default: walk the section's weight declarations from the root.

@classmethod
@abc.abstractmethod
def get_converters(cls, config: BaseModelConfig, exported_config: dict) -> list[WeightConverter]:
pass
Subclasses override when a section needs cross-section state from the full base-model config
(typically when an extension point on the head must read from a sibling section).
"""
return cls.emit_weight_converters(config, "", "")


class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC):
architecture: typing.ClassVar[str]
base_model_converter_class: typing.ClassVar[type[HuggingFaceBaseModelConverter]]

def __init__(self, model: "FastLLMModel"):
self._exported_config = self._export_config(model.config)
super().__init__(model)

@classmethod
@abc.abstractmethod
def get_transformers_configuration_class(cls) -> type["transformers.PretrainedConfig"]:
Expand Down Expand Up @@ -164,13 +160,10 @@ def _export_config(cls, config: FastLLMModelConfig) -> dict[str, typing.Any]:
def _check_hf_coverage(cls, config: dict[str, typing.Any]) -> None:
"""Run the HF-side coverage check at the import boundary.

Skips silently when the format's base-model converter isn't a :class:`ConfigSectionConverter`
(e.g. multimodal aggregators built on top of imperative ``HuggingFaceBaseModelConverter``).
Subclasses that override :meth:`_import_config` should call this explicitly to keep the check
active.
"""
if issubclass(cls.base_model_converter_class, ConfigSectionConverter):
cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST)
cls.base_model_converter_class.check_hf_coverage(config, allowlist=cls._HF_METADATA_ALLOWLIST)

@classmethod
def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
Expand All @@ -180,7 +173,7 @@ def _import_config(cls, config: dict[str, typing.Any]) -> FastLLMModelConfig:
return cls._model_class.from_dict({"base_model": cls.base_model_converter_class.import_config(config)})

def _create_weight_converters(self) -> list[WeightConverter]:
return self.base_model_converter_class.get_converters(self._model.config.base_model, self._exported_config)
return self.base_model_converter_class.get_converters(self._model.config.base_model)

def _load_weights(
self, config: CheckpointLoadConfig, device
Expand Down
Loading
Loading