Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
490 changes: 254 additions & 236 deletions docs/notebooks/structural_reliability.ipynb

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions panel/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -203,34 +203,36 @@
<fast-text-field id="search-input" placeholder="search" onInput="hideCards(event.target.value)"></fast-text-field>
</section>

<section id="cards">
<section id="cards">
<ul class="cards-grid">
<!-- Sampling card moved to first position -->
<li class="card">
<a class="card-link" href="./simdec_app.html" id="simdec_app">
<a class="card-link" href="./sampling.html" id="sampling">
<fast-card class="gallery-item">
<object data="_static/thumbnails/simdec_app.png" type="image/png">
<object data="_static/thumbnails/sampling.png" type="image/png">
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
</svg>
</object>
<div class="card-content">
<h2 class="card-header">SimDec App</h2>
<h2 class="card-header">Sampling</h2>
</div>
</fast-card>
</a>
</li>
<!-- SimDec App card moved to second position -->
<li class="card">
<a class="card-link" href="./sampling.html" id="sampling">
<a class="card-link" href="./simdec_app.html" id="simdec_app">
<fast-card class="gallery-item">
<object data="_static/thumbnails/sampling.png" type="image/png">
<object data="_static/thumbnails/simdec_app.png" type="image/png">
<svg class="card-image" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M2.5 4a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1zm2-.5a.5.5 0 1 1-1 0 .5.5 0 0 1 1 0zm1 .5a.5.5 0 1 0 0-1 .5.5 0 0 0 0 1z"/>
<path d="M2 1a2 2 0 0 0-2 2v10a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V3a2 2 0 0 0-2-2H2zm13 2v2H1V3a1 1 0 0 1 1-1h12a1 1 0 0 1 1 1zM2 14a1 1 0 0 1-1-1V6h14v7a1 1 0 0 1-1 1H2z"/>
</svg>
</object>
<div class="card-content">
<h2 class="card-header">Sampling</h2>
<h2 class="card-header">SimDec App</h2>
</div>
</fast-card>
</a>
Expand Down
10 changes: 6 additions & 4 deletions panel/simdec_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,11 @@ def filtered_si(sensitivity_indices_table, input_names):


def explained_variance_80(sensitivity_indices_table):
si = sensitivity_indices_table.value["Indices"]
pos_80 = bisect.bisect_right(np.cumsum(si), 0.8)
df = sensitivity_indices_table.value
df = df[df["Inputs"] != "Sum of Indices"]
si = df["Indices"].values
target = 0.8 * np.sum(si)
pos_80 = bisect.bisect_right(np.cumsum(si), target)

# pos_80 = max(2, pos_80)
# pos_80 = min(len(si), pos_80)
Expand Down Expand Up @@ -225,9 +228,8 @@ def create_color_pickers(states, colors):
@pn.cache
def palette_(states: list[list[str]], colors_picked: list[list[float]]):
cmaps = [single_color_to_colormap(color_picked) for color_picked in colors_picked]
# Reverse order as in figures high values take the first colors
states = [len(states_) for states_ in states]
return sd.palette(states, cmaps=cmaps[::-1])
return sd.palette(states, cmaps=cmaps)


@pn.cache
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ dashboard = [
"cryptography",
]

display = [
"ipython>=9.1"
]

test = [
"pytest",
"pytest-cov",
Expand All @@ -55,7 +59,7 @@ doc = [
]

dev = [
"simdec[doc,test,dashboard]",
"simdec[doc,test,dashboard, display]",
"watchfiles",
"pre-commit",
]
Expand Down
2 changes: 2 additions & 0 deletions src/simdec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""SimDec main namespace."""
from simdec.decomposition import *
from simdec.heterogeneity_indices import *
from simdec.sensitivity_indices import *
from simdec.visualization import *

Expand All @@ -11,4 +12,5 @@
"two_output_visualization",
"tableau",
"palette",
"heterogeneity_indices",
]
247 changes: 247 additions & 0 deletions src/simdec/heterogeneity_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
from dataclasses import dataclass
import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import simdec as sd

logger = logging.getLogger(__name__)

__all__ = ["heterogeneity_indices", "plot_heterogeneity"]


@dataclass
class HeterogeneityResult:
summary: pd.DataFrame
regional_profiles: pd.DataFrame
split_name: str


def heterogeneity_indices(
output: pd.Series,
inputs: pd.DataFrame,
split_variable: str | pd.Series,
n_subdivisions: int | None = None,
plot: bool = False,
) -> HeterogeneityResult:
"""Heterogeneity indices.

Compute sensitivity-based heterogeneity across subdivisions
of a variable.

Parameters
----------
output : pd.Series
Model output vector.
inputs : pd.DataFrame
Input/feature matrix.
split_variable : str or pd.Series
Variable to split on. If string, must be a column in 'inputs'.
n_subdivisions : int, optional
Number of regions for continuous variables. Defaults to 4.
plot : bool, default False
If True, displays a stacked bar chart of regional sensitivity profiles
by calling :func:`plot_heterogeneity`. The chart shows variance
contributions of each input across subdivisions of ``split_variable``,
ranked by global sensitivity indices. To capture the returned
``matplotlib.axes.Axes`` object, call :func:`plot_heterogeneity`
directly on the result instead.

Returns
-------
res : HeterogeneityResult
An object with attributes:

summary : DataFrame
A summary of calculated heterogeneity indices.
regional_profiles : DataFrame
Regional sensitivity indices for each input across subdivisions.
split_name : str
The name of the variable used to split the data.

"""
y = pd.Series(output).reset_index(drop=True)
X = pd.DataFrame(inputs).reset_index(drop=True)

if isinstance(split_variable, str):
if split_variable not in X.columns:
raise ValueError(f"'{split_variable}' not found in inputs.")
z = X[split_variable].reset_index(drop=True)
split_name = split_variable
else:
z = pd.Series(split_variable).reset_index(drop=True)
split_name = getattr(split_variable, "name", "split_variable")

unique_vals = z.dropna().unique()
n_unique = len(unique_vals)

# Determine if variable is categorical/binary
is_categorical = (
isinstance(z.dtype, pd.CategoricalDtype)
or pd.api.types.is_object_dtype(z)
or pd.api.types.is_string_dtype(z)
or pd.api.types.is_bool_dtype(z)
or n_unique <= 2
)

if is_categorical:
regions = z.astype("category")
else:
q = n_subdivisions if n_subdivisions is not None else 4
try:
regions = pd.qcut(z, q=q, duplicates="drop")
except ValueError as e:
raise ValueError(
f"Failed to bin '{split_name}' into {q} quantiles: {e}"
) from e

regional_profiles = []
skipped = []

for region in regions.cat.categories:
mask = regions == region
n_in_region = mask.sum()

if n_in_region < 10:
# Need enough samples for meaningful sensitivity indices
skipped.append((region, n_in_region, "too few samples (< 10)"))
continue

X_sub = X.loc[mask]
y_sub = y.loc[mask]

# Skip if output has zero or near-zero variance in this region
if y_sub.var() < 1e-12:
skipped.append((region, n_in_region, "output variance ≈ 0"))
continue

try:
res = sd.sensitivity_indices(inputs=X_sub, output=y_sub)
si_vals = np.asarray(res.si).ravel()

# Guard against NaN/Inf from degenerate sensitivity computation
if not np.all(np.isfinite(si_vals)):
skipped.append((region, n_in_region, "non-finite SI values"))
continue

si_region = pd.Series(si_vals, index=X.columns, name=region)
regional_profiles.append(si_region)

except Exception as e:
skipped.append((region, n_in_region, f"exception: {e}"))
continue

if skipped:
logger.info("Skipped %d region(s) of '%s':", len(skipped), split_name)
for reg, n, reason in skipped:
logger.info(" - region=%r, n=%d, reason=%s", reg, n, reason)

if len(regional_profiles) < 2:
total_regions = len(regions.cat.categories)
valid = len(regional_profiles)
raise ValueError(
f"Not enough valid subdivisions to compute heterogeneity: "
f"{valid}/{total_regions} regions passed all checks for '{split_name}'.\n"
f"Skipped regions:\n"
"\n".join(f" {r!r}: n={n}, {reason} " for r, n, reason in skipped),
"\n\nTry: (1) reducing n_subdivisions, "
"(2) using a different split_variable, or "
"(3) ensuring more samples per region.",
)

regional_si = pd.concat(regional_profiles, axis=1)

res_global = sd.sensitivity_indices(inputs=X, output=y)
overall_si = pd.Series(
np.asarray(res_global.si).ravel(),
index=X.columns,
name="Overall_SI",
)

# Heterogeneity = 2 × population std dev across regions
hetero_scores = 2 * regional_si.std(axis=1, ddof=0)
total_hetero = hetero_scores.mean()

hetero_col_name = f"Heterogeneity (across {split_name})"
summary = pd.DataFrame(
{"Overall_SI": overall_si, hetero_col_name: hetero_scores}
).sort_values(by=hetero_col_name, ascending=False)
summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]

result = HeterogeneityResult(summary, regional_si, split_name)

if plot:
Comment thread
tupui marked this conversation as resolved.
plot_heterogeneity(result)

return result


def plot_heterogeneity(result: HeterogeneityResult, ax: plt.Axes = None) -> plt.Axes:
"""Plot regional sensitivity profiles.

Parameters
----------
result : HeterogeneityResult
The result object from heterogeneity_indices.
ax : matplotlib.axes.Axes, optional
Existing axes to plot on.

Returns
-------
ax : matplotlib.axes.Axes
The axes with the plot.

"""
summary = result.summary
regional_si = result.regional_profiles
split_name = result.split_name

hetero_col_name = [c for c in summary.columns if "Heterogeneity" in c][0]
total_hetero = summary.loc["SUM / TOTAL", hetero_col_name]

plot_order = summary.index[summary.index != "SUM / TOTAL"]
plot_order = (
summary.loc[plot_order].sort_values(by="Overall_SI", ascending=False).index
)

cmap = plt.colormaps["terrain"]
colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(regional_si.index))]

data_to_plot = regional_si.loc[plot_order].T

if ax is None:
_, ax = plt.subplots(figsize=(10, 6))

data_to_plot.plot(
kind="bar",
stacked=True,
ax=ax,
color=colors,
edgecolor="white",
width=0.8,
)

ax.set_title(
f"Sensitivity Profiles across {split_name}\n"
f"(Total Heterogeneity: {total_hetero:.3f})",
fontsize=10,
)

ax.set_ylabel("Variance Contribution", fontsize=8)
ax.set_xlabel(f"Regions of {split_name}", fontsize=8)

ax.legend(
title="Inputs (Ranked by Global SI)",
bbox_to_anchor=(1.05, 1),
loc="upper left",
)

ax.tick_params(axis="x", labelrotation=45)
ax.grid(axis="y", linestyle="--", alpha=0.7)

if plt.get_backend().lower() != "agg":
plt.tight_layout()

return ax
Loading