Skip to content

Commit 34d1be6

Browse files
committed
feat: implement ignore_index for losses and metrics
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 958dac7 commit 34d1be6

9 files changed

Lines changed: 421 additions & 117 deletions

File tree

monai/check_ignore.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import torch
2+
import warnings
3+
import importlib.util
4+
from monai.losses import DiceLoss, TverskyLoss, FocalLoss
5+
from monai.metrics import DiceMetric, MeanIoU, ConfusionMatrixMetric, HausdorffDistanceMetric
6+
from monai.networks import one_hot
7+
8+
class IgnoreIndexTester:
9+
def __init__(self):
10+
# Target: [Batch, Channel, H, W]
11+
# Bottom row is ignored (255)
12+
self.target_ignore = torch.tensor(
13+
[[[[1, 0],
14+
[255, 255]]]],
15+
dtype=torch.long,
16+
)
17+
18+
self.target_std = torch.tensor(
19+
[[[[1, 0],
20+
[0, 1]]]],
21+
dtype=torch.long,
22+
)
23+
24+
self.input_1 = torch.tensor(
25+
[[[[2.0, -2.0],
26+
[5.0, -5.0]]]]
27+
)
28+
29+
self.input_2 = torch.tensor(
30+
[[[[2.0, -2.0],
31+
[-5.0, 5.0]]]]
32+
)
33+
34+
self.results = []
35+
36+
def log_result(self, name, v1, v2, should_match, mode):
37+
if isinstance(v1, (list, tuple)):
38+
v1 = v1[0]
39+
if isinstance(v2, (list, tuple)):
40+
v2 = v2[0]
41+
42+
match = torch.allclose(v1.float(), v2.float(), atol=1e-5, equal_nan=True)
43+
success = match == should_match
44+
status = "PASS" if success else "FAIL"
45+
self.results.append(f"{name:18} [{mode:8}] : {status}")
46+
47+
if not success:
48+
print(f"DEBUG {name} ({mode}) -> {v1.item():.6f} vs {v2.item():.6f}")
49+
50+
def run_loss_test(self, name, loss_cls, **kwargs):
51+
crit_ignore = loss_cls(ignore_index=255, **kwargs)
52+
v1_ign = crit_ignore(self.input_1, self.target_ignore)
53+
v2_ign = crit_ignore(self.input_2, self.target_ignore)
54+
self.log_result(name, v1_ign, v2_ign, True, "ignore")
55+
56+
crit_std = loss_cls(**kwargs)
57+
v1_std = crit_std(self.input_1, self.target_std)
58+
v2_std = crit_std(self.input_2, self.target_std)
59+
self.log_result(name, v1_std, v2_std, False, "standard")
60+
61+
def run_metric_test(self, name, metric_cls, **kwargs):
62+
# Probabilities for classes 0 and 1
63+
p1 = torch.cat([1 - torch.sigmoid(self.input_1), torch.sigmoid(self.input_1)], dim=1)
64+
p2 = torch.cat([1 - torch.sigmoid(self.input_2), torch.sigmoid(self.input_2)], dim=1)
65+
66+
def eval_metric(m, p_probs, t):
67+
m.reset()
68+
69+
# Hausdorff and distance metrics usually require One-Hot inputs
70+
if "Hausdorff" in name:
71+
y_pred = (p_probs > 0.5).float()
72+
# Clean target for one_hot conversion
73+
t_clean = t.clone()
74+
mask_ignore = (t == 255)
75+
t_clean[mask_ignore] = 0
76+
y_true = one_hot(t_clean, num_classes=p_probs.shape[1])
77+
# Zero out ignored regions in one-hot ground truth
78+
if mask_ignore.any():
79+
y_true[mask_ignore.expand_as(y_true)] = 0
80+
else:
81+
# Standard overlap metrics use Argmax indices
82+
y_pred = p_probs.argmax(dim=1, keepdim=True)
83+
y_true = t
84+
85+
m(y_pred=y_pred, y=y_true)
86+
return m.aggregate()
87+
88+
m_ignore = metric_cls(ignore_index=255, **kwargs)
89+
v1_ign = eval_metric(m_ignore, p1, self.target_ignore)
90+
v2_ign = eval_metric(m_ignore, p2, self.target_ignore)
91+
self.log_result(name, v1_ign, v2_ign, True, "ignore")
92+
93+
m_std = metric_cls(**kwargs)
94+
v1_std = eval_metric(m_std, p1, self.target_std)
95+
v2_std = eval_metric(m_std, p2, self.target_std)
96+
self.log_result(name, v1_std, v2_std, False, "standard")
97+
98+
def run_focal_softmax_standard(self):
99+
input_1_mc = torch.cat([self.input_1, -self.input_1], dim=1)
100+
input_2_mc = torch.cat([self.input_2, -self.input_2], dim=1)
101+
102+
crit = FocalLoss(use_softmax=True, to_onehot_y=True)
103+
v1 = crit(input_1_mc, self.target_std)
104+
v2 = crit(input_2_mc, self.target_std)
105+
self.log_result("FocalLoss", v1, v2, False, "softmax")
106+
107+
def execute(self):
108+
print("--- Starting IgnoreIndex Tests ---")
109+
110+
self.run_loss_test("DiceLoss", DiceLoss, sigmoid=True)
111+
self.run_loss_test("TverskyLoss", TverskyLoss, sigmoid=True)
112+
self.run_loss_test("FocalLoss", FocalLoss, use_softmax=False)
113+
self.run_focal_softmax_standard()
114+
115+
self.run_metric_test("DiceMetric", DiceMetric, include_background=True)
116+
self.run_metric_test("MeanIoU", MeanIoU, include_background=True)
117+
self.run_metric_test("Accuracy", ConfusionMatrixMetric, metric_name="accuracy")
118+
119+
if importlib.util.find_spec("scipy") is not None:
120+
self.run_metric_test("Hausdorff", HausdorffDistanceMetric, include_background=True)
121+
122+
print("\n--- TEST SUMMARY ---")
123+
for r in self.results:
124+
print(r)
125+
126+
if __name__ == "__main__":
127+
with warnings.catch_warnings():
128+
warnings.simplefilter("ignore")
129+
IgnoreIndexTester().execute()

0 commit comments

Comments
 (0)