From 8a1ae86ea920424daab04c99cc9534759d681ef9 Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Sat, 7 Mar 2026 10:14:02 +0100
Subject: [PATCH] fix: torch.compile changes the FQNs of parameters by adding
 "_orig_mod." as a prefix to the original FQN. This causes the regexes for
 matching parameter names to fail. To fix this, we need to remove the
 "_orig_mod." prefix from the parameter names before matching them against
 the regexes. This change needs to be made in both the
 llama3_like_initialization.py and initialization_routines.py files, wherever
 we are matching parameter names against regexes.

---
 src/modalities/models/gpt2/llama3_like_initialization.py | 7 +++++--
 .../nn/model_initialization/initialization_routines.py   | 3 +++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py
index d82a5be2c..5d1fa53af 100644
--- a/src/modalities/models/gpt2/llama3_like_initialization.py
+++ b/src/modalities/models/gpt2/llama3_like_initialization.py
@@ -99,10 +99,10 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
         }
 
     def initialize_in_place(self, model: nn.Module):
-        self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init)
+        self._init_by_fqn_regex(model, self.regex_to_init)
 
     @staticmethod
-    def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool):
+    def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]):
         hits = {k: 0 for k in regex_to_init.keys()}
 
         for parameter_name, p in model.named_parameters():
@@ -112,6 +112,9 @@ def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable
             )
             match_count = 0
             for weight_regex in regex_to_init.keys():
+                parameter_name = parameter_name.replace(
+                    "_orig_mod.", ""
+                )  # remove FQN modification from torch.compile if present
                 if re.fullmatch(weight_regex, parameter_name):
                     init_fn, arg_dict = regex_to_init[weight_regex]
                     if arg_dict["std"] is not None and callable(arg_dict["std"]):
diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py
index 5b4515875..deb6a2737 100644
--- a/src/modalities/nn/model_initialization/initialization_routines.py
+++ b/src/modalities/nn/model_initialization/initialization_routines.py
@@ -48,6 +48,9 @@ def initialize_in_place(self, model: nn.Module):
         weight_regexes = self.parameter_name_regexes.weights
         bias_regexes = self.parameter_name_regexes.biases
         for parameter_name, p in model.named_parameters():
+            parameter_name = parameter_name.replace(
+                "_orig_mod.", ""
+            )  # remove FQN modification from torch.compile if present
             for weight_regex in weight_regexes:
                 if re.fullmatch(weight_regex, parameter_name):
                     nn.init.normal_(p, mean=self.mean, std=self.std)