Skip to content

Commit bb698d6

Browse files
Merge branch 'dev' into improve-affine-docs-7092
2 parents 8565dd5 + 2147c11 commit bb698d6

File tree

3 files changed

+128
-14
lines changed

3 files changed

+128
-14
lines changed

monai/losses/unified_focal_loss.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
Args:
4545
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
4646
delta : weight of the background. Defaults to 0.7.
47-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
47+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
4848
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
4949
"""
5050
super().__init__(reduction=LossReduction(reduction).value)
@@ -108,7 +108,7 @@ def __init__(
108108
Args:
109109
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
110110
delta : weight of the background. Defaults to 0.7.
111-
gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75.
111+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
112112
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
113113
"""
114114
super().__init__(reduction=LossReduction(reduction).value)
@@ -167,10 +167,11 @@ def __init__(
167167
Args:
168168
to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
169169
num_classes : number of classes, it only supports 2 now. Defaults to 2.
170+
weight : weight for each loss function. Defaults to 0.5.
171+
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
170172
delta : weight of the background. Defaults to 0.7.
171-
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
172-
epsilon : it defines a very small number each time. similarly smooth value. Defaults to 1e-7.
173-
weight : weight for each loss function, if it's none it's 0.5. Defaults to None.
173+
174+
174175
175176
Example:
176177
>>> import torch

monai/transforms/croppad/functional.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,23 +91,27 @@ def pad_nd(
9191
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
9292
kwargs: other arguments for the `np.pad` or `torch.pad` function.
9393
note that `np.pad` treats channel dimension as the first dimension.
94+
Raises:
95+
ValueError: If `value` is provided when `mode` is not ``"constant"``.
9496
"""
97+
if mode != "constant" and "value" in kwargs:
98+
raise ValueError("'value' argument is only valid when mode='constant'")
9599
if mode in {"linear_ramp", "maximum", "mean", "median", "minimum", "symmetric", "empty"}:
96100
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
97101
try:
98102
_pad = _np_pad
99-
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"} and img.dtype not in {
100-
torch.int16,
101-
torch.int64,
102-
torch.bool,
103-
torch.uint8,
104-
}:
103+
if mode in {"constant", "reflect", "edge", "replicate", "wrap", "circular"}:
104+
# Try PyTorch pad for these modes; fallback to NumPy on error.
105105
_pad = _pt_pad
106106
return _pad(img, pad_width=to_pad, mode=mode, **kwargs)
107+
except NotImplementedError:
108+
# PyTorch does not support this combination, fall back to NumPy
109+
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
107110
except (ValueError, TypeError, RuntimeError) as err:
108-
if isinstance(err, NotImplementedError) or any(
109-
k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")
110-
):
111+
# PyTorch may raise generic errors for unsupported modes/dtypes or kwargs.
112+
# Since there are no stable exception types for these cases, we fall back
113+
# to NumPy by matching known error message patterns.
114+
if any(k in str(err) for k in ("supported", "unexpected keyword", "implemented", "value")):
111115
return _np_pad(img, pad_width=to_pad, mode=mode, **kwargs)
112116
raise ValueError(
113117
f"{img.shape} {to_pad} {mode} {kwargs} {img.dtype} {img.device if isinstance(img, torch.Tensor) else None}"
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
"""
12+
Tests for pad_nd dtype support and backend selection.
13+
Validates PyTorch padding preference and NumPy fallback behavior.
14+
"""
15+
from __future__ import annotations
16+
17+
import unittest
18+
from unittest.mock import Mock, patch
19+
20+
import torch
21+
from parameterized.parameterized import parameterized
22+
23+
import monai.transforms.croppad.functional as F
24+
from monai.transforms.croppad.functional import pad_nd
25+
26+
# Dtypes exercised by the single-mode (constant) padding test below:
# bool, signed ints of several widths, unsigned 8-bit, and a float type.
DTYPES = [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32]
# (mode, dtype) pairs exercised by the multi-mode test below; checks that
# dtype is preserved for constant, reflect and replicate padding.
MODES_DTYPES = [
    ("constant", torch.bool),
    ("constant", torch.int8),
    ("constant", torch.float32),
    ("reflect", torch.bool),
    ("reflect", torch.int8),
    ("reflect", torch.float32),
    ("replicate", torch.bool),
    ("replicate", torch.int8),
    ("replicate", torch.float32),
]
38+
39+
40+
class TestPadNdDtypes(unittest.TestCase):
    """Backend-selection and dtype-preservation tests for ``pad_nd``."""

    def test_pad_uses_pt_for_bool(self):
        """A bool tensor with constant mode should be padded by the PyTorch backend only."""
        image = torch.ones((1, 4, 4), dtype=torch.bool)
        padding = [(0, 0), (1, 1), (2, 2)]
        with (
            patch.object(F, "_pt_pad", wraps=F._pt_pad) as pt_spy,
            patch.object(F, "_np_pad", wraps=F._np_pad) as np_spy,
        ):
            result = pad_nd(image, padding, mode="constant", value=0)

        self.assertTrue(pt_spy.called)
        self.assertFalse(np_spy.called)
        self.assertEqual(result.dtype, image.dtype)
        self.assertEqual(result.shape, (1, 6, 8))

    def test_pad_falls_back_to_np_if_pt_raises(self):
        """When the PyTorch pad raises NotImplementedError, NumPy should take over."""
        image = torch.ones((1, 4, 4), dtype=torch.bool)
        padding = [(0, 0), (1, 1), (2, 2)]
        failing_pt = Mock(side_effect=NotImplementedError("no"))
        with (
            patch.object(F, "_pt_pad", new=failing_pt) as pt_spy,
            patch.object(F, "_np_pad", wraps=F._np_pad) as np_spy,
        ):
            result = pad_nd(image, padding, mode="constant", value=0)

        self.assertTrue(pt_spy.called)
        self.assertTrue(np_spy.called)
        self.assertEqual(result.dtype, image.dtype)
        self.assertEqual(result.shape, (1, 6, 8))

    @parameterized.expand(DTYPES)
    def test_pad_dtype_no_error_and_dtype_preserved(self, dtype):
        """Constant padding should succeed for every dtype and keep the dtype unchanged.

        Args:
            dtype: Input dtype under test.
        """
        image = torch.ones((1, 4, 4), dtype=dtype)
        result = pad_nd(image, [(0, 0), (1, 1), (2, 2)], mode="constant", value=0)

        self.assertEqual(result.shape, (1, 6, 8))
        self.assertEqual(result.dtype, image.dtype)

    @parameterized.expand(MODES_DTYPES)
    def test_pad_multiple_modes_dtype_preserved(self, mode, dtype):
        """Every (mode, dtype) combination should pad cleanly and preserve dtype.

        Args:
            mode: Padding mode under test.
            dtype: Input dtype under test.
        """
        image = torch.ones((1, 4, 4), dtype=dtype)
        padding = [(0, 0), (1, 1), (2, 2)]

        # 'value' is only a legal keyword for constant-mode padding.
        extra = {"value": 0} if mode == "constant" else {}
        result = pad_nd(image, padding, mode=mode, **extra)

        self.assertEqual(result.shape, (1, 6, 8))
        self.assertEqual(result.dtype, image.dtype)

    def test_value_with_non_constant_mode_raises(self):
        """Supplying 'value' alongside a non-constant mode must raise ValueError."""
        image = torch.ones((1, 4, 4))
        with self.assertRaises(ValueError):
            pad_nd(image, [(0, 0), (1, 1), (2, 2)], mode="reflect", value=0)
106+
107+
108+
# Allow running this test module directly (outside of a test runner).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments (0)