Skip to content

Commit 8a1ae86

Browse files
committed
fix: strip the "_orig_mod." prefix that torch.compile adds to parameter FQNs. torch.compile wraps the model and prefixes every parameter's fully qualified name with "_orig_mod.", which causes the regexes used to match parameter names to fail. Remove this prefix from each parameter name before matching it against the regexes, in both llama3_like_initialization.py and initialization_routines.py (everywhere parameter names are matched against regexes).
1 parent 627f84e commit 8a1ae86

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/modalities/models/gpt2/llama3_like_initialization.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,10 @@ def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
9999
}
100100

101101
def initialize_in_place(self, model: nn.Module):
102-
self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init)
102+
self._init_by_fqn_regex(model, self.regex_to_init)
103103

104104
@staticmethod
105-
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool):
105+
def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]]):
106106
hits = {k: 0 for k in regex_to_init.keys()}
107107

108108
for parameter_name, p in model.named_parameters():
@@ -112,6 +112,9 @@ def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable
112112
)
113113
match_count = 0
114114
for weight_regex in regex_to_init.keys():
115+
parameter_name = parameter_name.replace(
116+
"_orig_mod.", ""
117+
) # remove FQN modification from torch.compile if present
115118
if re.fullmatch(weight_regex, parameter_name):
116119
init_fn, arg_dict = regex_to_init[weight_regex]
117120
if arg_dict["std"] is not None and callable(arg_dict["std"]):

src/modalities/nn/model_initialization/initialization_routines.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def initialize_in_place(self, model: nn.Module):
4848
weight_regexes = self.parameter_name_regexes.weights
4949
bias_regexes = self.parameter_name_regexes.biases
5050
for parameter_name, p in model.named_parameters():
51+
parameter_name = parameter_name.replace(
52+
"_orig_mod.", ""
53+
) # remove FQN modification from torch.compile if present
5154
for weight_regex in weight_regexes:
5255
if re.fullmatch(weight_regex, parameter_name):
5356
nn.init.normal_(p, mean=self.mean, std=self.std)

0 commit comments

Comments
 (0)