Skip to content

Commit 5a2d0a7

Browse files
ytl0623 and ericspod authored
fix(autoencoderkl): handle proj_attn → out_proj key mapping in load_old_state_dict (#8786)
Fixes #8544 ### Description Map `proj_attn` → `out_proj` when both exist, initialise `out_proj` to identity/zero when only the new model has it, and silently discard `proj_attn` when only the old checkpoint has it. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <david89062388@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 6cd452e commit 5a2d0a7

2 files changed

Lines changed: 135 additions & 7 deletions

File tree

monai/networks/nets/autoencoderkl.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
680680
681681
Args:
682682
old_state_dict: state dict from the old AutoencoderKL model.
683+
verbose: if True, print diagnostic information about key mismatches.
683684
"""
684685

685686
new_state_dict = self.state_dict()
@@ -715,13 +716,39 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
715716
new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
716717
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
717718

718-
# old version did not have a projection so set these to the identity
719-
new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
720-
new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
721-
)
722-
new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
723-
new_state_dict[f"{block}.attn.out_proj.bias"].shape
724-
)
719+
out_w = f"{block}.attn.out_proj.weight"
720+
out_b = f"{block}.attn.out_proj.bias"
721+
proj_w = f"{block}.proj_attn.weight"
722+
proj_b = f"{block}.proj_attn.bias"
723+
724+
if out_w in new_state_dict:
725+
if proj_w in old_state_dict:
726+
new_state_dict[out_w] = old_state_dict.pop(proj_w)
727+
if proj_b in old_state_dict:
728+
new_state_dict[out_b] = old_state_dict.pop(proj_b)
729+
else:
730+
new_state_dict[out_b] = torch.zeros(
731+
new_state_dict[out_b].shape,
732+
dtype=new_state_dict[out_b].dtype,
733+
device=new_state_dict[out_b].device,
734+
)
735+
else:
736+
# No legacy proj_attn - initialize out_proj to identity/zero
737+
new_state_dict[out_w] = torch.eye(
738+
new_state_dict[out_w].shape[0],
739+
dtype=new_state_dict[out_w].dtype,
740+
device=new_state_dict[out_w].device,
741+
)
742+
new_state_dict[out_b] = torch.zeros(
743+
new_state_dict[out_b].shape,
744+
dtype=new_state_dict[out_b].dtype,
745+
device=new_state_dict[out_b].device,
746+
)
747+
elif proj_w in old_state_dict:
748+
# new model has no out_proj at all - discard the legacy keys so they
749+
# don't surface as "unexpected keys" during load_state_dict
750+
old_state_dict.pop(proj_w)
751+
old_state_dict.pop(proj_b, None)
725752

726753
# fix the upsample conv blocks which were renamed postconv
727754
for k in new_state_dict:

tests/networks/nets/test_autoencoderkl.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,17 @@
169169

170170

171171
class TestAutoEncoderKL(unittest.TestCase):
172+
_MIGRATION_PARAMS = {
173+
"spatial_dims": 2,
174+
"in_channels": 1,
175+
"out_channels": 1,
176+
"channels": (4, 4, 4),
177+
"latent_channels": 4,
178+
"attention_levels": (False, False, False),
179+
"num_res_blocks": 1,
180+
"norm_num_groups": 4,
181+
}
182+
172183
@parameterized.expand(CASES)
173184
def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
174185
net = AutoencoderKL(**input_param).to(device)
@@ -327,6 +338,96 @@ def test_compatibility_with_monai_generative(self):
327338

328339
net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)
329340

341+
@staticmethod
342+
def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict:
343+
"""Convert new-style state dict keys to legacy naming conventions.
344+
345+
Args:
346+
new_sd: State dict with current key naming.
347+
include_proj_attn: If True, map `.attn.out_proj.` to `.proj_attn.`.
348+
349+
Returns:
350+
State dict with legacy key names.
351+
"""
352+
old_sd: dict = {}
353+
for k, v in new_sd.items():
354+
if ".attn.to_q." in k:
355+
old_sd[k.replace(".attn.to_q.", ".to_q.")] = v.clone()
356+
elif ".attn.to_k." in k:
357+
old_sd[k.replace(".attn.to_k.", ".to_k.")] = v.clone()
358+
elif ".attn.to_v." in k:
359+
old_sd[k.replace(".attn.to_v.", ".to_v.")] = v.clone()
360+
elif ".attn.out_proj." in k:
361+
if include_proj_attn:
362+
old_sd[k.replace(".attn.out_proj.", ".proj_attn.")] = v.clone()
363+
elif "postconv" in k:
364+
old_sd[k.replace("postconv", "conv")] = v.clone()
365+
else:
366+
old_sd[k] = v.clone()
367+
return old_sd
368+
369+
@skipUnless(has_einops, "Requires einops")
370+
def test_load_old_state_dict_proj_attn_copied_to_out_proj(self):
371+
params = {**self._MIGRATION_PARAMS, "include_fc": True}
372+
src = AutoencoderKL(**params).to(device)
373+
old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=True)
374+
375+
# record the tensor values that were stored under proj_attn
376+
expected = {k.replace(".proj_attn.", ".attn.out_proj."): v for k, v in old_sd.items() if ".proj_attn." in k}
377+
self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict - check model config")
378+
379+
dst = AutoencoderKL(**params).to(device)
380+
dst.load_old_state_dict(old_sd)
381+
382+
for new_key, expected_val in expected.items():
383+
torch.testing.assert_close(
384+
dst.state_dict()[new_key], expected_val.to(device), msg=f"Weight mismatch for {new_key}"
385+
)
386+
387+
@skipUnless(has_einops, "Requires einops")
388+
def test_load_old_state_dict_missing_proj_attn_initialises_identity(self):
389+
params = {**self._MIGRATION_PARAMS, "include_fc": True}
390+
src = AutoencoderKL(**params).to(device)
391+
old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False)
392+
393+
dst = AutoencoderKL(**params).to(device)
394+
dst.load_old_state_dict(old_sd)
395+
loaded = dst.state_dict()
396+
397+
out_proj_weights = [k for k in loaded if "attn.out_proj.weight" in k]
398+
out_proj_biases = [k for k in loaded if "attn.out_proj.bias" in k]
399+
self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found - check model config")
400+
401+
for k in out_proj_weights:
402+
n = loaded[k].shape[0]
403+
torch.testing.assert_close(
404+
loaded[k], torch.eye(n, dtype=loaded[k].dtype, device=device), msg=f"{k} should be an identity matrix"
405+
)
406+
for k in out_proj_biases:
407+
torch.testing.assert_close(loaded[k], torch.zeros_like(loaded[k]), msg=f"{k} should be all-zeros")
408+
409+
@skipUnless(has_einops, "Requires einops")
410+
def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self):
411+
params = {**self._MIGRATION_PARAMS, "include_fc": False}
412+
src = AutoencoderKL(**params).to(device)
413+
old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False)
414+
415+
# inject synthetic proj_attn keys (mimic an old checkpoint)
416+
attn_blocks = [k.replace(".to_q.weight", "") for k in old_sd if k.endswith(".to_q.weight")]
417+
self.assertGreater(len(attn_blocks), 0, "No attention blocks found - check model config")
418+
for block in attn_blocks:
419+
ch = old_sd[f"{block}.to_q.weight"].shape[0]
420+
old_sd[f"{block}.proj_attn.weight"] = torch.randn(ch, ch)
421+
old_sd[f"{block}.proj_attn.bias"] = torch.randn(ch)
422+
423+
dst = AutoencoderKL(**params).to(device)
424+
dst.load_old_state_dict(old_sd)
425+
426+
loaded = dst.state_dict()
427+
self.assertFalse(
428+
any("out_proj" in k for k in loaded), "out_proj should not exist in a model built with include_fc=False"
429+
)
430+
330431

331432
if __name__ == "__main__":
332433
unittest.main()

0 commit comments

Comments (0)