169 | 169 |
170 | 170 |
171 | 171 | class TestAutoEncoderKL(unittest.TestCase): |
| 172 | + _MIGRATION_PARAMS = { |
| 173 | + "spatial_dims": 2, |
| 174 | + "in_channels": 1, |
| 175 | + "out_channels": 1, |
| 176 | + "channels": (4, 4, 4), |
| 177 | + "latent_channels": 4, |
| 178 | + "attention_levels": (False, False, False), |
| 179 | + "num_res_blocks": 1, |
| 180 | + "norm_num_groups": 4, |
| 181 | + } |
| 182 | + |
172 | 183 | @parameterized.expand(CASES) |
173 | 184 | def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape): |
174 | 185 | net = AutoencoderKL(**input_param).to(device) |
@@ -327,6 +338,96 @@ def test_compatibility_with_monai_generative(self): |
327 | 338 |
328 | 339 | net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False) |
329 | 340 |
| 341 | + @staticmethod |
| 342 | + def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict: |
| 343 | + """Convert new-style state dict keys to legacy naming conventions. |
| 344 | +
| 345 | + Args: |
| 346 | + new_sd: State dict with current key naming. |
| 347 | + include_proj_attn: If True, map `.attn.out_proj.` to `.proj_attn.`. |
| 348 | +
| 349 | + Returns: |
| 350 | + State dict with legacy key names. |
| 351 | + """ |
| 352 | + old_sd: dict = {} |
| 353 | + for k, v in new_sd.items(): |
| 354 | + if ".attn.to_q." in k: |
| 355 | + old_sd[k.replace(".attn.to_q.", ".to_q.")] = v.clone() |
| 356 | + elif ".attn.to_k." in k: |
| 357 | + old_sd[k.replace(".attn.to_k.", ".to_k.")] = v.clone() |
| 358 | + elif ".attn.to_v." in k: |
| 359 | + old_sd[k.replace(".attn.to_v.", ".to_v.")] = v.clone() |
| 360 | + elif ".attn.out_proj." in k: |
| 361 | + if include_proj_attn: |
| 362 | + old_sd[k.replace(".attn.out_proj.", ".proj_attn.")] = v.clone() |
| 363 | + elif "postconv" in k: |
| 364 | + old_sd[k.replace("postconv", "conv")] = v.clone() |
| 365 | + else: |
| 366 | + old_sd[k] = v.clone() |
| 367 | + return old_sd |
| 368 | + |
| 369 | + @skipUnless(has_einops, "Requires einops") |
| 370 | + def test_load_old_state_dict_proj_attn_copied_to_out_proj(self): |
| 371 | + params = {**self._MIGRATION_PARAMS, "include_fc": True} |
| 372 | + src = AutoencoderKL(**params).to(device) |
| 373 | + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=True) |
| 374 | + |
| 375 | + # record the tensor values that were stored under proj_attn |
| 376 | + expected = {k.replace(".proj_attn.", ".attn.out_proj."): v for k, v in old_sd.items() if ".proj_attn." in k} |
| 377 | + self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict - check model config") |
| 378 | + |
| 379 | + dst = AutoencoderKL(**params).to(device) |
| 380 | + dst.load_old_state_dict(old_sd) |
| 381 | + |
| 382 | + for new_key, expected_val in expected.items(): |
| 383 | + torch.testing.assert_close( |
| 384 | + dst.state_dict()[new_key], expected_val.to(device), msg=f"Weight mismatch for {new_key}" |
| 385 | + ) |
| 386 | + |
| 387 | + @skipUnless(has_einops, "Requires einops") |
| 388 | + def test_load_old_state_dict_missing_proj_attn_initialises_identity(self): |
| 389 | + params = {**self._MIGRATION_PARAMS, "include_fc": True} |
| 390 | + src = AutoencoderKL(**params).to(device) |
| 391 | + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False) |
| 392 | + |
| 393 | + dst = AutoencoderKL(**params).to(device) |
| 394 | + dst.load_old_state_dict(old_sd) |
| 395 | + loaded = dst.state_dict() |
| 396 | + |
| 397 | + out_proj_weights = [k for k in loaded if "attn.out_proj.weight" in k] |
| 398 | + out_proj_biases = [k for k in loaded if "attn.out_proj.bias" in k] |
| 399 | + self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found - check model config") |
| 400 | + |
| 401 | + for k in out_proj_weights: |
| 402 | + n = loaded[k].shape[0] |
| 403 | + torch.testing.assert_close( |
| 404 | + loaded[k], torch.eye(n, dtype=loaded[k].dtype, device=device), msg=f"{k} should be an identity matrix" |
| 405 | + ) |
| 406 | + for k in out_proj_biases: |
| 407 | + torch.testing.assert_close(loaded[k], torch.zeros_like(loaded[k]), msg=f"{k} should be all-zeros") |
| 408 | + |
| 409 | + @skipUnless(has_einops, "Requires einops") |
| 410 | + def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self): |
| 411 | + params = {**self._MIGRATION_PARAMS, "include_fc": False} |
| 412 | + src = AutoencoderKL(**params).to(device) |
| 413 | + old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False) |
| 414 | + |
| 415 | + # inject synthetic proj_attn keys (mimic an old checkpoint) |
| 416 | + attn_blocks = [k.replace(".to_q.weight", "") for k in old_sd if k.endswith(".to_q.weight")] |
| 417 | + self.assertGreater(len(attn_blocks), 0, "No attention blocks found - check model config") |
| 418 | + for block in attn_blocks: |
| 419 | + ch = old_sd[f"{block}.to_q.weight"].shape[0] |
| 420 | + old_sd[f"{block}.proj_attn.weight"] = torch.randn(ch, ch) |
| 421 | + old_sd[f"{block}.proj_attn.bias"] = torch.randn(ch) |
| 422 | + |
| 423 | + dst = AutoencoderKL(**params).to(device) |
| 424 | + dst.load_old_state_dict(old_sd) |
| 425 | + |
| 426 | + loaded = dst.state_dict() |
| 427 | + self.assertFalse( |
| 428 | + any("out_proj" in k for k in loaded), "out_proj should not exist in a model built with include_fc=False" |
| 429 | + ) |
| 430 | + |
330 | 431 |
331 | 432 | if __name__ == "__main__": |
332 | 433 | unittest.main() |