From eb55abe76e45bf3215eaf747496f4987afe903dc Mon Sep 17 00:00:00 2001 From: Oleksandr Sanin Date: Wed, 27 May 2026 09:21:35 +0000 Subject: [PATCH 1/4] fix(networks): replace Tensor | None union syntax with Optional[Tensor] for TorchScript compatibility The `|` union type syntax (e.g. `torch.Tensor | None`) was introduced in Python 3.10. While `from __future__ import annotations` defers evaluation at runtime, TorchScript's annotation parser does not support this syntax and fails when scripting models that contain these forward method signatures. Replace `torch.Tensor | None` with `Optional[torch.Tensor]` in the `forward` methods of: - `monai/networks/blocks/crossattention.py` (CrossAttentionBlock) - `monai/networks/blocks/selfattention.py` (SABlock) - `monai/networks/blocks/transformerblock.py` (TransformerBlock) These three blocks are used in the ViT/UNETR scripting path, causing `RuntimeError: Can't redefine method: forward on class` when `torch.jit.script()` is called on a UNETR model. Closes #7939 Signed-off-by: Oleksandr Sanin --- monai/networks/blocks/crossattention.py | 3 ++- monai/networks/blocks/selfattention.py | 3 ++- monai/networks/blocks/transformerblock.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 40722ae881..66317149f1 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -139,7 +140,7 @@ def __init__( ) self.input_size = input_size - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None): + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index d2ad24ac19..3ff16b05ca 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -158,7 +159,7 @@ def __init__( ) self.input_size = input_size - def forward(self, x, attn_mask: torch.Tensor | None = None): + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index b93d81bdef..02a4186fbd 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from typing import Optional from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock @@ -89,7 +90,7 @@ def __init__( ) def forward( - self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None + self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: x = x + self.attn(self.norm1(x), attn_mask=attn_mask) if self.with_cross_attention: From a80169b595e4071b55747ee7ae30fb7366598bfd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 May 2026 17:09:44 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/networks/blocks/crossattention.py | 3 +-- monai/networks/blocks/selfattention.py | 3 +-- monai/networks/blocks/transformerblock.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 66317149f1..40722ae881 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn -from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -140,7 +139,7 @@ def __init__( ) self.input_size = input_size - def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None): """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 3ff16b05ca..d2ad24ac19 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -14,7 +14,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -159,7 +158,7 @@ def __init__( ) self.input_size = input_size - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): + def forward(self, x, attn_mask: torch.Tensor | None = None): """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 02a4186fbd..b93d81bdef 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -13,7 +13,6 @@ import torch import torch.nn as nn -from typing import Optional from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock @@ -90,7 +89,7 @@ def __init__( ) def forward( - self, x: torch.Tensor, context: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None + self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None ) -> torch.Tensor: x = x + self.attn(self.norm1(x), attn_mask=attn_mask) if self.with_cross_attention: From d18c3f98b424fdda580a1be8bf0d69f0b5b32874 Mon Sep 17 00:00:00 2001 From: Oleksandr Sanin Date: Fri, 29 May 2026 09:39:55 +0000 Subject: [PATCH 3/4] fix(networks): use Optional[Tensor] with noqa to preserve TorchScript compat TorchScript's annotation parser does not support the PEP 604 `X | None` union syntax. Replace `torch.Tensor | None` with `Optional[torch.Tensor]` in the `forward` methods of CrossAttentionBlock, SABlock, and TransformerBlock. Add `# noqa: UP045` on each affected line so ruff (pyupgrade) does not auto-revert the annotations back to the `X | None` form. Closes #7939 Signed-off-by: Oleksandr Sanin --- monai/networks/blocks/crossattention.py | 3 ++- monai/networks/blocks/selfattention.py | 3 ++- monai/networks/blocks/transformerblock.py | 6 +++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index 40722ae881..f14cd64d1f 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -139,7 +140,7 @@ def __init__( ) self.input_size = input_size - def forward(self, x: torch.Tensor, context: torch.Tensor | None = None): + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None): # noqa: UP045 """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index d2ad24ac19..4fdd99c921 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -14,6 +14,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import @@ -158,7 +159,7 @@ def __init__( ) self.input_size = input_size - def forward(self, x, attn_mask: torch.Tensor | None = None): + def forward(self, x, attn_mask: Optional[torch.Tensor] = None): # noqa: UP045 """ Args: x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index b93d81bdef..024899236f 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -13,6 +13,7 @@ import torch import torch.nn as nn +from typing import Optional from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock @@ -89,7 +90,10 @@ def __init__( ) def forward( - self, x: torch.Tensor, context: torch.Tensor | None = None, attn_mask: torch.Tensor | None = None + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, # noqa: UP045 + attn_mask: Optional[torch.Tensor] = None, # noqa: UP045 ) -> torch.Tensor: x = x + self.attn(self.norm1(x), attn_mask=attn_mask) if self.with_cross_attention: From 176357d8ebd37908d118f688fcf72181854c3921 Mon Sep 17 00:00:00 2001 From: Oleksandr Sanin Date: Fri, 29 May 2026 13:36:29 +0000 Subject: [PATCH 4/4] fix(networks): sort typing import and prune cross-attn branch for TorchScript Two follow-up fixes so that scripting UNETR (test_unetr::test_script) fully passes and the static-checks (codeformat) job is green: 1. Move ``from typing import Optional`` into the standard-library import group (before the third-party ``torch`` imports) in crossattention.py, selfattention.py and transformerblock.py. isort (profile=black) requires this ordering; without it the codeformat check failed. 2. Add ``__constants__ = ["with_cross_attention"]`` to TransformerBlock. PR #8848 made ``cross_attn`` an ``nn.Identity`` when ``with_cross_attention`` is False. TorchScript statically compiles every branch, so the ``self.cross_attn(..., context=context)`` call was checked against ``nn.Identity.forward`` (which has no ``context`` argument) and scripting failed. Marking the flag as a TorchScript constant lets the compiler prune the dead cross-attention branch when it is False, while still keeping ``cross_attn`` as ``nn.Identity`` (no registered params), preserving the behaviour and tests added in #8848. The typing.Final annotation form cannot be used here because ``from __future__ import annotations`` stringizes it and TorchScript cannot resolve ``'Final[bool]'``; the ``__constants__`` list avoids that. Closes #7939 Signed-off-by: Oleksandr Sanin --- monai/networks/blocks/crossattention.py | 3 ++- monai/networks/blocks/selfattention.py | 3 ++- monai/networks/blocks/transformerblock.py | 8 +++++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py index f14cd64d1f..f6e930dfff 100644 --- a/monai/networks/blocks/crossattention.py +++ b/monai/networks/blocks/crossattention.py @@ -11,9 +11,10 @@ from __future__ import annotations +from typing import Optional + import torch import torch.nn as nn -from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 4fdd99c921..0f5fd73ca0 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -11,10 +11,11 @@ from __future__ import annotations +from typing import Optional + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional from monai.networks.layers.utils import get_rel_pos_embedding_layer from monai.utils import optional_import diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py index 1b45204c46..bb030a4d75 100644 --- a/monai/networks/blocks/transformerblock.py +++ b/monai/networks/blocks/transformerblock.py @@ -11,9 +11,10 @@ from __future__ import annotations +from typing import Optional + import torch import torch.nn as nn -from typing import Optional from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock @@ -24,6 +25,11 @@ class TransformerBlock(nn.Module): An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " """ + # Treat ``with_cross_attention`` as a TorchScript constant so the cross-attention branch in + # ``forward`` is pruned when it is False. Otherwise scripting tries to compile the + # ``self.cross_attn(..., context=context)`` call against ``nn.Identity`` and fails. + __constants__ = ["with_cross_attention"] + def __init__( self, hidden_size: int,