Skip to content

Commit 084ea22

Browse files
committed
Add heterogeneity indices, printing indices/legend, fix index.html.
- Add heterogeneity indices function. - Add functionality for printing indices and legend. sensitivity_indices.py edited to get var_names for printing. - Added guardrails for output parameter in decomposition.py - Added tests for printing legend. - Initialize decomposition in dashboard with 0.8*sum(si). Other variables still can be chosen after this. - Updated docs to print correct second-order effects. Closes #46, Closes #47
1 parent 21b8623 commit 084ea22

9 files changed

Lines changed: 599 additions & 252 deletions

File tree

docs/notebooks/structural_reliability.ipynb

Lines changed: 254 additions & 236 deletions
Large diffs are not rendered by default.

panel/index.html

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,34 +203,36 @@
203203
<fast-text-field id="search-input" placeholder="search" onInput="hideCards(event.target.value)"></fast-text-field>
204204
</section>
205205

206-
<section id="cards">
206+
<section id="cards">
207207
<ul class="cards-grid">
208+
<!-- Sampling card moved to first position -->
208209
<li class="card">
209-
<a class="card-link" href="./simdec_app.html" id="simdec_app">
210+
<a class="card-link" href="./sampling.html" id="sampling">
210211
<fast-card class="gallery-item">
211-
<object data="_static/thumbnails/simdec_app.png" type="image/png">
212+
<object data="_static/thumbnails/sampling.png" type="image/png">
212213
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
213214
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
214215
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
215216
</svg>
216217
</object>
217218
<div class="card-content">
218-
<h2 class="card-header">SimDec App</h2>
219+
<h2 class="card-header">Sampling</h2>
219220
</div>
220221
</fast-card>
221222
</a>
222223
</li>
224+
<!-- SimDec App card moved to second position -->
223225
<li class="card">
224-
<a class="card-link" href="./sampling.html" id="sampling">
226+
<a class="card-link" href="./simdec_app.html" id="simdec_app">
225227
<fast-card class="gallery-item">
226-
<object data="_static/thumbnails/sampling.png" type="image/png">
228+
<object data="_static/thumbnails/simdec_app.png" type="image/png">
227229
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
228230
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
229231
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
230232
</svg>
231233
</object>
232234
<div class="card-content">
233-
<h2 class="card-header">Sampling</h2>
235+
<h2 class="card-header">SimDec App</h2>
234236
</div>
235237
</fast-card>
236238
</a>

panel/simdec_app.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,11 @@ def filtered_si(sensitivity_indices_table, input_names):
170170

171171

172172
def explained_variance_80(sensitivity_indices_table):
173-
si = sensitivity_indices_table.value["Indices"]
174-
pos_80 = bisect.bisect_right(np.cumsum(si), 0.8)
173+
df = sensitivity_indices_table.value
174+
df = df[df["Inputs"] != "Sum of Indices"]
175+
si = df["Indices"].values
176+
target = 0.8 * np.sum(si)
177+
pos_80 = bisect.bisect_right(np.cumsum(si), target)
175178

176179
# pos_80 = max(2, pos_80)
177180
# pos_80 = min(len(si), pos_80)

src/simdec/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from simdec.decomposition import *
33
from simdec.sensitivity_indices import *
44
from simdec.visualization import *
5+
from simdec.heterogeneity_indices import *
56

67
__all__ = [
78
"sensitivity_indices",
@@ -11,4 +12,5 @@
1112
"two_output_visualization",
1213
"tableau",
1314
"palette",
15+
"heterogeneity_indices",
1416
]

src/simdec/decomposition.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __reduce__(self):
6565

6666
def decomposition(
6767
inputs: pd.DataFrame,
68-
output: pd.DataFrame,
68+
output: pd.DataFrame | np.ndarray,
6969
*,
7070
sensitivity_indices: np.ndarray,
7171
dec_limit: float | None = None,
@@ -116,7 +116,11 @@ def decomposition(
116116
inputs[cat_col] = codes
117117

118118
inputs = inputs.to_numpy()
119-
output = output.to_numpy().flatten()
119+
120+
if hasattr(output, "to_numpy"):
121+
output = output.to_numpy().flatten()
122+
else:
123+
output = np.asarray(output).flatten()
120124

121125
# 1. variables for decomposition
122126
var_order = np.argsort(sensitivity_indices)[::-1]
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
from .sensitivity_indices import sensitivity_indices
2+
import numpy as np
3+
import pandas as pd
4+
import matplotlib.pyplot as plt
5+
6+
__all__ = ["heterogeneity_indices"]
7+
8+
9+
def heterogeneity_indices(
10+
output: pd.Series,
11+
inputs: pd.DataFrame,
12+
split_variable: str | pd.Series,
13+
n_subdivisions: int | None = None,
14+
plot: bool = False,
15+
) -> pd.DataFrame:
16+
"""
17+
Compute sensitivity-based heterogeneity across subdivisions of a variable.
18+
19+
Parameters
20+
----------
21+
output : pd.Series
22+
Model output vector.
23+
inputs : pd.DataFrame
24+
Input/feature matrix.
25+
split_variable : str or pd.Series
26+
Variable to split on. If string, must be a column in 'inputs'.
27+
n_subdivisions : int, optional
28+
Number of regions for continuous variables. Defaults to 4.
29+
plot : bool, default False
30+
If True, displays a stacked bar chart of regional sensitivities.
31+
32+
Returns
33+
----------
34+
summary : pd.Dataframe
35+
A summary of calculated heterogeneity indices.
36+
"""
37+
y = pd.Series(output).reset_index(drop=True)
38+
X = pd.DataFrame(inputs).reset_index(drop=True)
39+
40+
if isinstance(split_variable, str):
41+
if split_variable not in X.columns:
42+
raise ValueError(f"'{split_variable}' not found in inputs.")
43+
z = X[split_variable].reset_index(drop=True)
44+
split_name = split_variable
45+
else:
46+
z = pd.Series(split_variable).reset_index(drop=True)
47+
split_name = getattr(split_variable, "name", "split_variable")
48+
49+
unique_vals = z.dropna().unique()
50+
n_unique = len(unique_vals)
51+
52+
# Determine if variable is categorical/binary
53+
is_categorical = (
54+
pd.api.types.is_categorical_dtype(z)
55+
or pd.api.types.is_object_dtype(z)
56+
or pd.api.types.is_bool_dtype(z)
57+
or n_unique <= 2
58+
)
59+
60+
if is_categorical:
61+
regions = z.astype("category")
62+
else:
63+
q = n_subdivisions if n_subdivisions is not None else 4
64+
try:
65+
regions = pd.qcut(z, q=q, duplicates="drop")
66+
except ValueError as e:
67+
raise ValueError(
68+
f"Failed to bin '{split_name}' into {q} quantiles: {e}"
69+
) from e
70+
71+
regional_profiles = []
72+
skipped = []
73+
74+
for region in regions.cat.categories:
75+
mask = regions == region
76+
n_in_region = mask.sum()
77+
78+
if n_in_region < 10:
79+
# Need enough samples for meaningful sensitivity indices
80+
skipped.append((region, n_in_region, "too few samples (< 10)"))
81+
continue
82+
83+
X_sub = X.loc[mask]
84+
y_sub = y.loc[mask]
85+
86+
# Skip if output has zero or near-zero variance in this region
87+
if y_sub.var() < 1e-12:
88+
skipped.append((region, n_in_region, "output variance ≈ 0"))
89+
continue
90+
91+
try:
92+
res = sensitivity_indices(inputs=X_sub, output=y_sub)
93+
si_vals = np.asarray(res.si).ravel()
94+
95+
# Guard against NaN/Inf from degenerate sensitivity computation
96+
if not np.all(np.isfinite(si_vals)):
97+
skipped.append((region, n_in_region, "non-finite SI values"))
98+
continue
99+
100+
si_region = pd.Series(si_vals, index=X.columns, name=region)
101+
regional_profiles.append(si_region)
102+
103+
except Exception as e:
104+
skipped.append((region, n_in_region, f"exception: {e}"))
105+
continue
106+
107+
if skipped:
108+
print(
109+
f"[heterogeneity_indices] Skipped {len(skipped)} region(s) of '{split_name}':"
110+
)
111+
for reg, n, reason in skipped:
112+
print(f" - region={reg!r}, n={n}, reason={reason}")
113+
114+
if len(regional_profiles) < 2:
115+
total_regions = len(regions.cat.categories)
116+
valid = len(regional_profiles)
117+
raise ValueError(
118+
f"Not enough valid subdivisions to compute heterogeneity: "
119+
f"{valid}/{total_regions} regions passed all checks for '{split_name}'.\n"
120+
f"Skipped regions:\n"
121+
+ "\n".join(f" {r!r}: n={n}, {reason}" for r, n, reason in skipped)
122+
+ "\n\nTry: (1) reducing n_subdivisions, "
123+
"(2) using a different split_variable, or "
124+
"(3) ensuring more samples per region."
125+
)
126+
127+
regional_si = pd.concat(regional_profiles, axis=1)
128+
129+
res_global = sensitivity_indices(inputs=X, output=y)
130+
overall_si = pd.Series(
131+
np.asarray(res_global.si).ravel(),
132+
index=X.columns,
133+
name="Overall_SI",
134+
)
135+
136+
# Heterogeneity = 2 × population std dev across regions
137+
hetero_scores = 2 * regional_si.std(axis=1, ddof=0)
138+
total_hetero = hetero_scores.mean()
139+
140+
hetero_col_name = f"Heterogeneity (across {split_name})"
141+
summary = pd.DataFrame(
142+
{"Overall_SI": overall_si, hetero_col_name: hetero_scores}
143+
).sort_values(by=hetero_col_name, ascending=False)
144+
summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]
145+
146+
if plot:
147+
plot_order = summary.index[:-1]
148+
data_to_plot = regional_si.loc[plot_order].T
149+
150+
cmap = plt.get_cmap("terrain")
151+
colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(plot_order))]
152+
153+
_ = data_to_plot.plot(
154+
kind="bar",
155+
stacked=True,
156+
figsize=(10, 6),
157+
color=colors,
158+
edgecolor="white",
159+
width=0.8,
160+
)
161+
162+
plt.title(f"Sensitivity Profiles across {split_name}", fontsize=14)
163+
plt.ylabel("Variance Contribution", fontsize=12)
164+
plt.xlabel(f"Regions of {split_name}", fontsize=12)
165+
plt.legend(title="Input Variables", bbox_to_anchor=(1.05, 1), loc="upper left")
166+
plt.xticks(rotation=45)
167+
plt.grid(axis="y", linestyle="--", alpha=0.7)
168+
plt.tight_layout()
169+
plt.show()
170+
171+
return summary

src/simdec/sensitivity_indices.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ class SensitivityAnalysisResult:
3737

3838

3939
def sensitivity_indices(
40-
inputs: pd.DataFrame | np.ndarray, output: pd.DataFrame | np.ndarray
40+
inputs: pd.DataFrame | np.ndarray,
41+
output: pd.DataFrame | np.ndarray,
42+
print_indices: bool = False,
4143
) -> SensitivityAnalysisResult:
4244
"""Sensitivity indices.
4345
@@ -50,6 +52,8 @@ def sensitivity_indices(
5052
Input variables.
5153
output : ndarray or DataFrame of shape (n_runs, 1)
5254
Target variable.
55+
print_indices : bool, default False
56+
If True, displays computed indices.
5357
5458
Returns
5559
-------
@@ -97,11 +101,18 @@ def sensitivity_indices(
97101
"""
98102
# Handle inputs conversion
99103
if isinstance(inputs, pd.DataFrame):
100-
cat_columns = inputs.select_dtypes(["category", "O"]).columns
101-
inputs[cat_columns] = inputs[cat_columns].apply(
102-
lambda x: x.astype("category").cat.codes
103-
)
104+
var_names = inputs.columns.tolist()
105+
cat_cols = inputs.select_dtypes(["category", "O"]).columns
106+
if not cat_cols.empty:
107+
inputs = inputs.copy() # Avoid SettingWithCopyWarning
108+
inputs[cat_cols] = inputs[cat_cols].apply(
109+
lambda x: x.astype("category").cat.codes
110+
)
104111
inputs = inputs.to_numpy()
112+
else:
113+
inputs = np.asarray(inputs)
114+
# Fallback names if it's just a numpy array
115+
var_names = [f"x{i}" for i in range(inputs.shape[1])]
105116

106117
# Handle output conversion first, then flatten
107118
if isinstance(output, (pd.DataFrame, pd.Series)):
@@ -181,4 +192,14 @@ def sensitivity_indices(
181192
for k in range(n_factors):
182193
si[k] = foe[k] + (soe[:, k].sum() / 2)
183194

195+
if print_indices:
196+
df_foe = pd.DataFrame(foe, index=var_names, columns=["First-order effect"])
197+
df_soe = pd.DataFrame(soe, index=var_names, columns=var_names)
198+
df_si = pd.DataFrame(si, index=var_names, columns=["Combined effect"])
199+
200+
df_indices = pd.concat([df_foe, df_soe, df_si], axis=1)
201+
print(f"{'-'*69}")
202+
print(df_indices)
203+
print(f"{'-'*69}")
204+
184205
return SensitivityAnalysisResult(si, foe, soe)

0 commit comments

Comments
 (0)