1- from .sensitivity_indices import sensitivity_indices
1+ from dataclasses import dataclass
2+ import logging
3+
4+ import matplotlib .pyplot as plt
25import numpy as np
36import pandas as pd
4- import matplotlib .pyplot as plt
57
6- __all__ = ["heterogeneity_indices" ]
8+ import simdec as sd
9+
10+ logger = logging .getLogger (__name__ )
11+
12+ __all__ = ["heterogeneity_indices" , "plot_heterogeneity" ]
13+
14+
@dataclass
class HeterogeneityResult:
    """Container for the outputs of a heterogeneity analysis.

    Attributes
    ----------
    summary : DataFrame
        A summary of calculated heterogeneity indices.
    regional_profiles : DataFrame
        Regional sensitivity indices for each input across subdivisions.
    split_name : str
        The name of the variable used to split the data.

    """

    summary: pd.DataFrame
    regional_profiles: pd.DataFrame
    split_name: str
720
821
922def heterogeneity_indices (
@@ -12,9 +25,11 @@ def heterogeneity_indices(
1225 split_variable : str | pd .Series ,
1326 n_subdivisions : int | None = None ,
1427 plot : bool = False ,
15- ) -> pd .DataFrame :
16- """
17- Compute sensitivity-based heterogeneity across subdivisions of a variable.
28+ ) -> HeterogeneityResult :
29+ """Heterogeneity indices.
30+
31+ Compute sensitivity-based heterogeneity across subdivisions
32+ of a variable.
1833
1934 Parameters
2035 ----------
@@ -27,12 +42,25 @@ def heterogeneity_indices(
2742 n_subdivisions : int, optional
2843 Number of regions for continuous variables. Defaults to 4.
2944 plot : bool, default False
30- If True, displays a stacked bar chart of regional sensitivities.
45+ If True, displays a stacked bar chart of regional sensitivity profiles
46+ by calling :func:`plot_heterogeneity`. The chart shows variance
47+ contributions of each input across subdivisions of ``split_variable``,
48+ ranked by global sensitivity indices. To capture the returned
49+ ``matplotlib.axes.Axes`` object, call :func:`plot_heterogeneity`
50+ directly on the result instead.
3151
3252 Returns
33- ----------
34- summary : pd.Dataframe
35- A summary of calculated heterogeneity indices.
53+ -------
54+ res : HeterogeneityResult
55+ An object with attributes:
56+
57+ summary : DataFrame
58+ A summary of calculated heterogeneity indices.
59+ regional_profiles : DataFrame
60+ Regional sensitivity indices for each input across subdivisions.
61+ split_name : str
62+ The name of the variable used to split the data.
63+
3664 """
3765 y = pd .Series (output ).reset_index (drop = True )
3866 X = pd .DataFrame (inputs ).reset_index (drop = True )
@@ -51,8 +79,9 @@ def heterogeneity_indices(
5179
5280 # Determine if variable is categorical/binary
5381 is_categorical = (
54- pd . api . types . is_categorical_dtype ( z )
82+ isinstance ( z . dtype , pd . CategoricalDtype )
5583 or pd .api .types .is_object_dtype (z )
84+ or pd .api .types .is_string_dtype (z )
5685 or pd .api .types .is_bool_dtype (z )
5786 or n_unique <= 2
5887 )
@@ -89,7 +118,7 @@ def heterogeneity_indices(
89118 continue
90119
91120 try :
92- res = sensitivity_indices (inputs = X_sub , output = y_sub )
121+ res = sd . sensitivity_indices (inputs = X_sub , output = y_sub )
93122 si_vals = np .asarray (res .si ).ravel ()
94123
95124 # Guard against NaN/Inf from degenerate sensitivity computation
@@ -105,11 +134,9 @@ def heterogeneity_indices(
105134 continue
106135
107136 if skipped :
108- print (
109- f"[heterogeneity_indices] Skipped { len (skipped )} region(s) of '{ split_name } ':"
110- )
137+ logger .info ("Skipped %d region(s) of '%s':" , len (skipped ), split_name )
111138 for reg , n , reason in skipped :
112- print ( f " - region={ reg !r } , n={ n } , reason={ reason } " )
139+ logger . info ( " - region=%r , n=%d , reason=%s" , reg , n , reason )
113140
114141 if len (regional_profiles ) < 2 :
115142 total_regions = len (regions .cat .categories )
@@ -118,15 +145,15 @@ def heterogeneity_indices(
118145 f"Not enough valid subdivisions to compute heterogeneity: "
119146 f"{ valid } /{ total_regions } regions passed all checks for '{ split_name } '.\n "
120147 f"Skipped regions:\n "
121- + "\n " .join (f" { r !r} : n={ n } , { reason } " for r , n , reason in skipped )
122- + "\n \n Try: (1) reducing n_subdivisions, "
148+ "\n " .join (f" { r !r} : n={ n } , { reason } " for r , n , reason in skipped ),
149+ "\n \n Try: (1) reducing n_subdivisions, "
123150 "(2) using a different split_variable, or "
124- "(3) ensuring more samples per region."
151+ "(3) ensuring more samples per region." ,
125152 )
126153
127154 regional_si = pd .concat (regional_profiles , axis = 1 )
128155
129- res_global = sensitivity_indices (inputs = X , output = y )
156+ res_global = sd . sensitivity_indices (inputs = X , output = y )
130157 overall_si = pd .Series (
131158 np .asarray (res_global .si ).ravel (),
132159 index = X .columns ,
@@ -143,29 +170,70 @@ def heterogeneity_indices(
143170 ).sort_values (by = hetero_col_name , ascending = False )
144171 summary .loc ["SUM / TOTAL" ] = [overall_si .sum (), total_hetero ]
145172
173+ result = HeterogeneityResult (summary , regional_si , split_name )
174+
146175 if plot :
147- plot_order = summary .index [:- 1 ]
148- data_to_plot = regional_si .loc [plot_order ].T
149-
150- cmap = plt .get_cmap ("terrain" )
151- colors = [cmap (i ) for i in np .linspace (0.05 , 0.95 , len (plot_order ))]
152-
153- _ = data_to_plot .plot (
154- kind = "bar" ,
155- stacked = True ,
156- figsize = (10 , 6 ),
157- color = colors ,
158- edgecolor = "white" ,
159- width = 0.8 ,
160- )
176+ plot_heterogeneity (result )
177+
178+ return result
179+
180+
def plot_heterogeneity(
    result: HeterogeneityResult, ax: plt.Axes | None = None
) -> plt.Axes:
    """Plot regional sensitivity profiles.

    Draws a stacked bar chart with one bar per subdivision of the split
    variable and one stacked segment per input. Inputs are ordered by
    descending ``Overall_SI`` taken from ``result.summary``; the
    ``"SUM / TOTAL"`` row of the summary is excluded.

    Parameters
    ----------
    result : HeterogeneityResult
        The result object from heterogeneity_indices.
    ax : matplotlib.axes.Axes, optional
        Existing axes to plot on. When omitted, a new 10x6 figure and axes
        are created.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes with the plot.

    """
    summary = result.summary
    regional_si = result.regional_profiles
    split_name = result.split_name

    # Drop the aggregate row, then rank the remaining inputs by global SI.
    inputs_only = summary.loc[summary.index != "SUM / TOTAL"]
    plot_order = inputs_only.sort_values(by="Overall_SI", ascending=False).index

    # One colour per plotted input (i.e. per stacked segment).
    cmap = plt.colormaps["terrain"]
    colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(plot_order))]

    # Transpose so that rows are regions (bars) and columns are inputs.
    data_to_plot = regional_si.loc[plot_order].T

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))

    data_to_plot.plot(
        kind="bar",
        stacked=True,
        ax=ax,
        color=colors,
        edgecolor="white",
        width=0.8,
    )

    ax.set_title(f"Sensitivity Profiles across {split_name}", fontsize=14)
    ax.set_ylabel("Variance Contribution", fontsize=12)
    ax.set_xlabel(f"Regions of {split_name}", fontsize=12)

    ax.legend(
        title="Inputs (Ranked by Global SI)",
        bbox_to_anchor=(1.05, 1),
        loc="upper left",
    )

    ax.tick_params(axis="x", labelrotation=45)
    ax.grid(axis="y", linestyle="--", alpha=0.7)

    # NOTE(review): layout is skipped on the "agg" backend, presumably to keep
    # headless/test runs quiet -- confirm the original intent.
    if plt.get_backend().lower() != "agg":
        # Lay out the figure that owns ``ax`` rather than whichever figure is
        # current in pyplot; a caller-supplied ``ax`` may belong to another
        # figure. Containers without ``tight_layout`` are left untouched.
        tight_layout = getattr(ax.figure, "tight_layout", None)
        if tight_layout is not None:
            tight_layout()

    return ax
0 commit comments