Skip to content

Commit 625c39e

Browse files
committed
Add test, typing info, nits, imports.
- Remove decomposition guardrail. - Use one print for print_indices. - Add typing info to visualization.py - Specify decomposition in visualization.py - Add Ipython into pyproject.toml and add guardrail for import. - Add stacklevel=2 into warnings. - Make a function for plotting heterogeneity indices. - Work with ax instead of plt. - Delete unnecessary +. - Add logging instead of print in heterogeneity_indices.py. - Set import order correct. Closes #46, Closes #47.
1 parent 084ea22 commit 625c39e

7 files changed

Lines changed: 282 additions & 70 deletions

File tree

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ dashboard = [
4141
"cryptography",
4242
]
4343

44+
ipython = [
45+
"ipython"
46+
]
47+
4448
test = [
4549
"pytest",
4650
"pytest-cov",
@@ -55,7 +59,7 @@ doc = [
5559
]
5660

5761
dev = [
58-
"simdec[doc,test,dashboard]",
62+
"simdec[doc,test,dashboard, ipython]",
5963
"watchfiles",
6064
"pre-commit",
6165
]

src/simdec/decomposition.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __reduce__(self):
6565

6666
def decomposition(
6767
inputs: pd.DataFrame,
68-
output: pd.DataFrame | np.ndarray,
68+
output: pd.DataFrame,
6969
*,
7070
sensitivity_indices: np.ndarray,
7171
dec_limit: float | None = None,
@@ -116,11 +116,7 @@ def decomposition(
116116
inputs[cat_col] = codes
117117

118118
inputs = inputs.to_numpy()
119-
120-
if hasattr(output, "to_numpy"):
121-
output = output.to_numpy().flatten()
122-
else:
123-
output = np.asarray(output).flatten()
119+
output = output.to_numpy().flatten()
124120

125121
# 1. variables for decomposition
126122
var_order = np.argsort(sensitivity_indices)[::-1]

src/simdec/heterogeneity_indices.py

Lines changed: 110 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1-
from .sensitivity_indices import sensitivity_indices
1+
from dataclasses import dataclass
2+
import logging
3+
4+
import matplotlib.pyplot as plt
25
import numpy as np
36
import pandas as pd
4-
import matplotlib.pyplot as plt
57

6-
__all__ = ["heterogeneity_indices"]
8+
import simdec as sd
9+
10+
logger = logging.getLogger(__name__)
11+
12+
__all__ = ["heterogeneity_indices", "plot_heterogeneity"]
13+
14+
15+
@dataclass
16+
class HeterogeneityResult:
17+
summary: pd.DataFrame
18+
regional_profiles: pd.DataFrame
19+
split_name: str
720

821

922
def heterogeneity_indices(
@@ -12,9 +25,11 @@ def heterogeneity_indices(
1225
split_variable: str | pd.Series,
1326
n_subdivisions: int | None = None,
1427
plot: bool = False,
15-
) -> pd.DataFrame:
16-
"""
17-
Compute sensitivity-based heterogeneity across subdivisions of a variable.
28+
) -> HeterogeneityResult:
29+
"""Heterogeneity indices.
30+
31+
Compute sensitivity-based heterogeneity across subdivisions
32+
of a variable.
1833
1934
Parameters
2035
----------
@@ -27,12 +42,25 @@ def heterogeneity_indices(
2742
n_subdivisions : int, optional
2843
Number of regions for continuous variables. Defaults to 4.
2944
plot : bool, default False
30-
If True, displays a stacked bar chart of regional sensitivities.
45+
If True, displays a stacked bar chart of regional sensitivity profiles
46+
by calling :func:`plot_heterogeneity`. The chart shows variance
47+
contributions of each input across subdivisions of ``split_variable``,
48+
ranked by global sensitivity indices. To capture the returned
49+
``matplotlib.axes.Axes`` object, call :func:`plot_heterogeneity`
50+
directly on the result instead.
3151
3252
Returns
33-
----------
34-
summary : pd.Dataframe
35-
A summary of calculated heterogeneity indices.
53+
-------
54+
res : HeterogeneityResult
55+
An object with attributes:
56+
57+
summary : DataFrame
58+
A summary of calculated heterogeneity indices.
59+
regional_profiles : DataFrame
60+
Regional sensitivity indices for each input across subdivisions.
61+
split_name : str
62+
The name of the variable used to split the data.
63+
3664
"""
3765
y = pd.Series(output).reset_index(drop=True)
3866
X = pd.DataFrame(inputs).reset_index(drop=True)
@@ -51,8 +79,9 @@ def heterogeneity_indices(
5179

5280
# Determine if variable is categorical/binary
5381
is_categorical = (
54-
pd.api.types.is_categorical_dtype(z)
82+
isinstance(z.dtype, pd.CategoricalDtype)
5583
or pd.api.types.is_object_dtype(z)
84+
or pd.api.types.is_string_dtype(z)
5685
or pd.api.types.is_bool_dtype(z)
5786
or n_unique <= 2
5887
)
@@ -89,7 +118,7 @@ def heterogeneity_indices(
89118
continue
90119

91120
try:
92-
res = sensitivity_indices(inputs=X_sub, output=y_sub)
121+
res = sd.sensitivity_indices(inputs=X_sub, output=y_sub)
93122
si_vals = np.asarray(res.si).ravel()
94123

95124
# Guard against NaN/Inf from degenerate sensitivity computation
@@ -105,11 +134,9 @@ def heterogeneity_indices(
105134
continue
106135

107136
if skipped:
108-
print(
109-
f"[heterogeneity_indices] Skipped {len(skipped)} region(s) of '{split_name}':"
110-
)
137+
logger.info("Skipped %d region(s) of '%s':", len(skipped), split_name)
111138
for reg, n, reason in skipped:
112-
print(f" - region={reg!r}, n={n}, reason={reason}")
139+
logger.info(" - region=%r, n=%d, reason=%s", reg, n, reason)
113140

114141
if len(regional_profiles) < 2:
115142
total_regions = len(regions.cat.categories)
@@ -118,15 +145,15 @@ def heterogeneity_indices(
118145
f"Not enough valid subdivisions to compute heterogeneity: "
119146
f"{valid}/{total_regions} regions passed all checks for '{split_name}'.\n"
120147
f"Skipped regions:\n"
121-
+ "\n".join(f" {r!r}: n={n}, {reason}" for r, n, reason in skipped)
122-
+ "\n\nTry: (1) reducing n_subdivisions, "
148+
"\n".join(f" {r!r}: n={n}, {reason} " for r, n, reason in skipped),
149+
"\n\nTry: (1) reducing n_subdivisions, "
123150
"(2) using a different split_variable, or "
124-
"(3) ensuring more samples per region."
151+
"(3) ensuring more samples per region.",
125152
)
126153

127154
regional_si = pd.concat(regional_profiles, axis=1)
128155

129-
res_global = sensitivity_indices(inputs=X, output=y)
156+
res_global = sd.sensitivity_indices(inputs=X, output=y)
130157
overall_si = pd.Series(
131158
np.asarray(res_global.si).ravel(),
132159
index=X.columns,
@@ -143,29 +170,70 @@ def heterogeneity_indices(
143170
).sort_values(by=hetero_col_name, ascending=False)
144171
summary.loc["SUM / TOTAL"] = [overall_si.sum(), total_hetero]
145172

173+
result = HeterogeneityResult(summary, regional_si, split_name)
174+
146175
if plot:
147-
plot_order = summary.index[:-1]
148-
data_to_plot = regional_si.loc[plot_order].T
149-
150-
cmap = plt.get_cmap("terrain")
151-
colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(plot_order))]
152-
153-
_ = data_to_plot.plot(
154-
kind="bar",
155-
stacked=True,
156-
figsize=(10, 6),
157-
color=colors,
158-
edgecolor="white",
159-
width=0.8,
160-
)
176+
plot_heterogeneity(result)
177+
178+
return result
179+
180+
181+
def plot_heterogeneity(result: HeterogeneityResult, ax: plt.Axes = None) -> plt.Axes:
182+
"""Plot regional sensitivity profiles.
183+
184+
Parameters
185+
----------
186+
result : HeterogeneityResult
187+
The result object from heterogeneity_indices.
188+
ax : matplotlib.axes.Axes, optional
189+
Existing axes to plot on.
190+
191+
Returns
192+
-------
193+
ax : matplotlib.axes.Axes
194+
The axes with the plot.
195+
196+
"""
197+
summary = result.summary
198+
regional_si = result.regional_profiles
199+
split_name = result.split_name
200+
201+
plot_order = summary.index[summary.index != "SUM / TOTAL"]
202+
plot_order = (
203+
summary.loc[plot_order].sort_values(by="Overall_SI", ascending=False).index
204+
)
205+
206+
cmap = plt.colormaps["terrain"]
207+
colors = [cmap(i) for i in np.linspace(0.05, 0.95, len(regional_si.index))]
208+
209+
data_to_plot = regional_si.loc[plot_order].T
210+
211+
if ax is None:
212+
_, ax = plt.subplots(figsize=(10, 6))
213+
214+
data_to_plot.plot(
215+
kind="bar",
216+
stacked=True,
217+
ax=ax,
218+
color=colors,
219+
edgecolor="white",
220+
width=0.8,
221+
)
222+
223+
ax.set_title(f"Sensitivity Profiles across {split_name}", fontsize=14)
224+
ax.set_ylabel("Variance Contribution", fontsize=12)
225+
ax.set_xlabel(f"Regions of {split_name}", fontsize=12)
226+
227+
ax.legend(
228+
title="Inputs (Ranked by Global SI)",
229+
bbox_to_anchor=(1.05, 1),
230+
loc="upper left",
231+
)
232+
233+
ax.tick_params(axis="x", labelrotation=45)
234+
ax.grid(axis="y", linestyle="--", alpha=0.7)
161235

162-
plt.title(f"Sensitivity Profiles across {split_name}", fontsize=14)
163-
plt.ylabel("Variance Contribution", fontsize=12)
164-
plt.xlabel(f"Regions of {split_name}", fontsize=12)
165-
plt.legend(title="Input Variables", bbox_to_anchor=(1.05, 1), loc="upper left")
166-
plt.xticks(rotation=45)
167-
plt.grid(axis="y", linestyle="--", alpha=0.7)
236+
if plt.get_backend().lower() != "agg":
168237
plt.tight_layout()
169-
plt.show()
170238

171-
return summary
239+
return ax

src/simdec/sensitivity_indices.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def sensitivity_indices(
102102
# Handle inputs conversion
103103
if isinstance(inputs, pd.DataFrame):
104104
var_names = inputs.columns.tolist()
105-
cat_cols = inputs.select_dtypes(["category", "O"]).columns
105+
cat_cols = inputs.select_dtypes(include=["category", "O", "string"]).columns
106106
if not cat_cols.empty:
107107
inputs = inputs.copy() # Avoid SettingWithCopyWarning
108108
inputs[cat_cols] = inputs[cat_cols].apply(
@@ -198,8 +198,6 @@ def sensitivity_indices(
198198
df_si = pd.DataFrame(si, index=var_names, columns=["Combined effect"])
199199

200200
df_indices = pd.concat([df_foe, df_soe, df_si], axis=1)
201-
print(f"{'-'*69}")
202-
print(df_indices)
203-
print(f"{'-'*69}")
201+
print(f"\n{df_indices}\n")
204202

205203
return SensitivityAnalysisResult(si, foe, soe)

src/simdec/visualization.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,18 @@
1010
import seaborn as sns
1111
import pandas as pd
1212
from pandas.io.formats.style import Styler
13+
import warnings
14+
15+
from simdec.decomposition import DecompositionResult
1316

1417
__all__ = ["visualization", "two_output_visualization", "tableau", "palette"]
1518

19+
try:
20+
from IPython.display import display
21+
22+
HAS_IPYTHON = True
23+
except ImportError:
24+
HAS_IPYTHON = False
1625

1726
SEQUENTIAL_PALETTES = [
1827
"#DC267F",
@@ -140,7 +149,7 @@ def visualization(
140149
kind: Literal["histogram", "boxplot"] = "histogram",
141150
ax=None,
142151
print_legend: bool = False,
143-
decomposition=None,
152+
decomposition: DecompositionResult | None = None,
144153
) -> plt.Axes:
145154
"""Histogram plot of scenarios.
146155
@@ -158,7 +167,7 @@ def visualization(
158167
Matplotlib axis.
159168
print_legend: Boolean, optional
160169
Prints plot legend.
161-
decomposition: Object, optional
170+
decomposition: DecompositionResult, optional
162171
Required for print_legend.
163172
164173
Returns
@@ -194,13 +203,16 @@ def visualization(
194203
raise ValueError("'kind' can only be 'histogram' or 'boxplot'")
195204

196205
if print_legend:
197-
from IPython.display import display
198-
199-
if decomposition is None:
200-
import warnings
201-
206+
if not HAS_IPYTHON:
207+
warnings.warn(
208+
"print_legend=True requires ipython to be installed. "
209+
"Install it with: pip install simdec[ipython]",
210+
stacklevel=2,
211+
)
212+
elif decomposition is None:
202213
warnings.warn(
203-
"print_legend=True requires the decomposition object. Table skipped."
214+
"print_legend=True requires the decomposition parameter. Table skipped.",
215+
stacklevel=2,
204216
)
205217
else:
206218
try:
@@ -229,7 +241,7 @@ def two_output_visualization(
229241
ylim: tuple[float, float] | None = None,
230242
r_scatter: float = 1.0,
231243
print_legend: bool = False,
232-
decomposition=None,
244+
decomposition: DecompositionResult | None = None,
233245
) -> tuple[plt.Figure, np.ndarray]:
234246
"""Two-output visualization.
235247
@@ -261,7 +273,7 @@ def two_output_visualization(
261273
Fraction of data points shown in the scatter plot.
262274
print_legend: Boolean, optional
263275
Prints plot legend.
264-
decomposition: Object, optional
276+
decomposition: DecompositionResult, optional
265277
Required for print_legend.
266278
267279
Returns
@@ -322,13 +334,16 @@ def two_output_visualization(
322334
fig.subplots_adjust(wspace=-0.015, hspace=0)
323335

324336
if print_legend:
325-
from IPython.display import display
326-
327-
if decomposition is None:
328-
import warnings
329-
337+
if not HAS_IPYTHON:
338+
warnings.warn(
339+
"print_legend=True requires ipython to be installed. "
340+
"Install it with: pip install simdec[ipython]",
341+
stacklevel=2,
342+
)
343+
elif decomposition is None:
330344
warnings.warn(
331-
"print_legend=True requires the decomposition object. Table skipped."
345+
"print_legend=True requires the decomposition parameter. Table skipped.",
346+
stacklevel=2,
332347
)
333348
else:
334349
try:

0 commit comments

Comments
 (0)