Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ class SwinUNETR(nn.Module):
Swin UNETR based on: "Hatamizadeh et al.,
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
<https://arxiv.org/abs/2201.01266>"

Spatial Shape Constraints:
Each spatial dimension of the input must be divisible by ``patch_size ** 5``.
With the default ``patch_size=2``, this means each spatial dimension must be divisible by **32**
(i.e., 2^5 = 32). This requirement comes from the patch embedding step followed by 4 stages
of PatchMerging downsampling, each halving the spatial resolution.

For a custom ``patch_size``, the divisibility requirement is ``patch_size ** 5``.

Examples of valid 3D input sizes (with default ``patch_size=2``):
``(32, 32, 32)``, ``(64, 64, 64)``, ``(96, 96, 96)``, ``(128, 128, 128)``, ``(64, 32, 192)``.

A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.
"""

def __init__(
Expand Down Expand Up @@ -76,7 +89,8 @@ def __init__(
Args:
in_channels: dimension of input channels.
out_channels: dimension of output channels.
patch_size: size of the patch token.
patch_size: size of the patch token. Input spatial dimensions must be divisible by
``patch_size ** 5`` (e.g., divisible by 32 when ``patch_size=2``).
feature_size: dimension of network feature size.
depths: number of layers in each stage.
num_heads: number of attention heads.
Expand Down Expand Up @@ -108,6 +122,10 @@ def __init__(
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
>>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)

Raises:
ValueError: When a spatial dimension of the input is not divisible by ``patch_size ** 5``.
Use ``net._check_input_size(spatial_shape)`` to validate a shape before inference.

"""

super().__init__()
Expand Down
11 changes: 11 additions & 0 deletions tests/networks/nets/test_swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SwinUNETR(in_channels=1, out_channels=3, feature_size=24, norm_name="instance", drop_rate=-1)

@skipUnless(has_einops, "Requires einops")
def test_invalid_input_shape(self):
# spatial dims not divisible by patch_size**5 (default patch_size=2, so must be divisible by 32)
net = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=3)
with self.assertRaises(ValueError):
net(torch.randn(1, 1, 33, 64, 64)) # 33 is not divisible by 32

net_2d = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=2)
with self.assertRaises(ValueError):
net_2d(torch.randn(1, 1, 48, 33)) # 33 is not divisible by 32

def test_patch_merging(self):
dim = 10
t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))
Expand Down
Loading