1+ import torch
2+ import warnings
3+ import importlib .util
4+ from monai .losses import DiceLoss , TverskyLoss , FocalLoss
5+ from monai .metrics import DiceMetric , MeanIoU , ConfusionMatrixMetric , HausdorffDistanceMetric
6+ from monai .networks import one_hot
7+
8+ class IgnoreIndexTester :
9+ def __init__ (self ):
10+ # Target: [Batch, Channel, H, W]
11+ # Bottom row is ignored (255)
12+ self .target_ignore = torch .tensor (
13+ [[[[1 , 0 ],
14+ [255 , 255 ]]]],
15+ dtype = torch .long ,
16+ )
17+
18+ self .target_std = torch .tensor (
19+ [[[[1 , 0 ],
20+ [0 , 1 ]]]],
21+ dtype = torch .long ,
22+ )
23+
24+ self .input_1 = torch .tensor (
25+ [[[[2.0 , - 2.0 ],
26+ [5.0 , - 5.0 ]]]]
27+ )
28+
29+ self .input_2 = torch .tensor (
30+ [[[[2.0 , - 2.0 ],
31+ [- 5.0 , 5.0 ]]]]
32+ )
33+
34+ self .results = []
35+
36+ def log_result (self , name , v1 , v2 , should_match , mode ):
37+ if isinstance (v1 , (list , tuple )):
38+ v1 = v1 [0 ]
39+ if isinstance (v2 , (list , tuple )):
40+ v2 = v2 [0 ]
41+
42+ match = torch .allclose (v1 .float (), v2 .float (), atol = 1e-5 , equal_nan = True )
43+ success = match == should_match
44+ status = "PASS" if success else "FAIL"
45+ self .results .append (f"{ name :18} [{ mode :8} ] : { status } " )
46+
47+ if not success :
48+ print (f"DEBUG { name } ({ mode } ) -> { v1 .item ():.6f} vs { v2 .item ():.6f} " )
49+
50+ def run_loss_test (self , name , loss_cls , ** kwargs ):
51+ crit_ignore = loss_cls (ignore_index = 255 , ** kwargs )
52+ v1_ign = crit_ignore (self .input_1 , self .target_ignore )
53+ v2_ign = crit_ignore (self .input_2 , self .target_ignore )
54+ self .log_result (name , v1_ign , v2_ign , True , "ignore" )
55+
56+ crit_std = loss_cls (** kwargs )
57+ v1_std = crit_std (self .input_1 , self .target_std )
58+ v2_std = crit_std (self .input_2 , self .target_std )
59+ self .log_result (name , v1_std , v2_std , False , "standard" )
60+
61+ def run_metric_test (self , name , metric_cls , ** kwargs ):
62+ # Probabilities for classes 0 and 1
63+ p1 = torch .cat ([1 - torch .sigmoid (self .input_1 ), torch .sigmoid (self .input_1 )], dim = 1 )
64+ p2 = torch .cat ([1 - torch .sigmoid (self .input_2 ), torch .sigmoid (self .input_2 )], dim = 1 )
65+
66+ def eval_metric (m , p_probs , t ):
67+ m .reset ()
68+
69+ # Hausdorff and distance metrics usually require One-Hot inputs
70+ if "Hausdorff" in name :
71+ y_pred = (p_probs > 0.5 ).float ()
72+ # Clean target for one_hot conversion
73+ t_clean = t .clone ()
74+ mask_ignore = (t == 255 )
75+ t_clean [mask_ignore ] = 0
76+ y_true = one_hot (t_clean , num_classes = p_probs .shape [1 ])
77+ # Zero out ignored regions in one-hot ground truth
78+ if mask_ignore .any ():
79+ y_true [mask_ignore .expand_as (y_true )] = 0
80+ else :
81+ # Standard overlap metrics use Argmax indices
82+ y_pred = p_probs .argmax (dim = 1 , keepdim = True )
83+ y_true = t
84+
85+ m (y_pred = y_pred , y = y_true )
86+ return m .aggregate ()
87+
88+ m_ignore = metric_cls (ignore_index = 255 , ** kwargs )
89+ v1_ign = eval_metric (m_ignore , p1 , self .target_ignore )
90+ v2_ign = eval_metric (m_ignore , p2 , self .target_ignore )
91+ self .log_result (name , v1_ign , v2_ign , True , "ignore" )
92+
93+ m_std = metric_cls (** kwargs )
94+ v1_std = eval_metric (m_std , p1 , self .target_std )
95+ v2_std = eval_metric (m_std , p2 , self .target_std )
96+ self .log_result (name , v1_std , v2_std , False , "standard" )
97+
98+ def run_focal_softmax_standard (self ):
99+ input_1_mc = torch .cat ([self .input_1 , - self .input_1 ], dim = 1 )
100+ input_2_mc = torch .cat ([self .input_2 , - self .input_2 ], dim = 1 )
101+
102+ crit = FocalLoss (use_softmax = True , to_onehot_y = True )
103+ v1 = crit (input_1_mc , self .target_std )
104+ v2 = crit (input_2_mc , self .target_std )
105+ self .log_result ("FocalLoss" , v1 , v2 , False , "softmax" )
106+
107+ def execute (self ):
108+ print ("--- Starting IgnoreIndex Tests ---" )
109+
110+ self .run_loss_test ("DiceLoss" , DiceLoss , sigmoid = True )
111+ self .run_loss_test ("TverskyLoss" , TverskyLoss , sigmoid = True )
112+ self .run_loss_test ("FocalLoss" , FocalLoss , use_softmax = False )
113+ self .run_focal_softmax_standard ()
114+
115+ self .run_metric_test ("DiceMetric" , DiceMetric , include_background = True )
116+ self .run_metric_test ("MeanIoU" , MeanIoU , include_background = True )
117+ self .run_metric_test ("Accuracy" , ConfusionMatrixMetric , metric_name = "accuracy" )
118+
119+ if importlib .util .find_spec ("scipy" ) is not None :
120+ self .run_metric_test ("Hausdorff" , HausdorffDistanceMetric , include_background = True )
121+
122+ print ("\n --- TEST SUMMARY ---" )
123+ for r in self .results :
124+ print (r )
125+
126+ if __name__ == "__main__" :
127+ with warnings .catch_warnings ():
128+ warnings .simplefilter ("ignore" )
129+ IgnoreIndexTester ().execute ()
0 commit comments