From cb21557cc04e5a67414753f3cff1b8bcf9f9272a Mon Sep 17 00:00:00 2001 From: Adrian Caderno Date: Fri, 10 Apr 2026 18:25:40 +0200 Subject: [PATCH] Add spatial shape constraint docs and test for SwinUNETR (#6771) Signed-off-by: Adrian Caderno --- monai/networks/nets/swin_unetr.py | 20 +++++++++++++++++++- tests/networks/nets/test_swin_unetr.py | 11 +++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index b4d93c9afe..0db2d50d26 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -47,6 +47,19 @@ class SwinUNETR(nn.Module): Swin UNETR based on: "Hatamizadeh et al., Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images " + + Spatial Shape Constraints: + Each spatial dimension of the input must be divisible by ``patch_size ** 5``. + With the default ``patch_size=2``, this means each spatial dimension must be divisible by **32** + (i.e., 2^5 = 32). With ``patch_size=2``, this factor arises from the patch embedding step followed by 4 stages + of PatchMerging downsampling, each halving the spatial resolution. + + For a custom ``patch_size``, the divisibility requirement enforced by the network is ``patch_size ** 5``. + + Examples of valid 3D input sizes (with default ``patch_size=2``): + ``(32, 32, 32)``, ``(64, 64, 64)``, ``(96, 96, 96)``, ``(128, 128, 128)``, ``(64, 32, 192)``. + + A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint. """ def __init__( @@ -76,7 +89,8 @@ def __init__( Args: in_channels: dimension of input channels. out_channels: dimension of output channels. - patch_size: size of the patch token. + patch_size: size of the patch token. Input spatial dimensions must be divisible by + ``patch_size ** 5`` (e.g., divisible by 32 when ``patch_size=2``). feature_size: dimension of network feature size. depths: number of layers in each stage. num_heads: number of attention heads. 
@@ -108,6 +122,10 @@ def __init__( # for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing. >>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2) + Raises: + ValueError: When a spatial dimension of the input is not divisible by ``patch_size ** 5``. + Use ``net._check_input_size(spatial_shape)`` to validate a shape before inference. + """ super().__init__() diff --git a/tests/networks/nets/test_swin_unetr.py b/tests/networks/nets/test_swin_unetr.py index cc15158b43..ba94aab4f9 100644 --- a/tests/networks/nets/test_swin_unetr.py +++ b/tests/networks/nets/test_swin_unetr.py @@ -90,6 +90,17 @@ def test_ill_arg(self): with self.assertRaises(ValueError): SwinUNETR(in_channels=1, out_channels=3, feature_size=24, norm_name="instance", drop_rate=-1) + @skipUnless(has_einops, "Requires einops") + def test_invalid_input_shape(self): + # spatial dims not divisible by patch_size**5 (default patch_size=2, so must be divisible by 32) + net = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=3) + with self.assertRaises(ValueError): + net(torch.randn(1, 1, 33, 64, 64)) # 33 is not divisible by 32 + + net_2d = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=2) + with self.assertRaises(ValueError): + net_2d(torch.randn(1, 1, 48, 33)) # 33 is not divisible by 32 + def test_patch_merging(self): dim = 10 t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))