Skip to content

Commit 8e4beda

Browse files
committed
Add two-output graph for visualization.py. Add csv validation for dashboard to catch wrong delimiters or incorrect column names. Closes #45. Closes #50.
1 parent 5cedad3 commit 8e4beda

File tree

3 files changed

+167
-11
lines changed

3 files changed

+167
-11
lines changed

panel/simdec_app.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import bisect
22
import io
3+
import re
34

45
from bokeh.models import PrintfTickFormatter
56
from bokeh.models.widgets.tables import NumberFormatter
@@ -15,10 +16,8 @@
1516
from simdec.sensitivity_indices import SensitivityAnalysisResult
1617
from simdec.visualization import sequential_cmaps, single_color_to_colormap
1718

18-
1919
# panel app
20-
pn.extension("tabulator")
21-
pn.extension("floatpanel")
20+
pn.extension("tabulator", "floatpanel", notifications=True)
2221

2322
pn.config.sizing_mode = "stretch_width"
2423
pn.config.throttled = True
@@ -43,34 +42,82 @@
4342
)
4443

4544

45+
def _validate_csv_bytes(raw_bytes):
46+
"""Pre-parse validation. Returns an error string or None."""
47+
try:
48+
first_line = raw_bytes.decode("utf-8").split("\n")[0].strip()
49+
except UnicodeDecodeError:
50+
return "File encoding error. Please use files in UTF-8."
51+
52+
if "," not in first_line:
53+
detected = (
54+
"Semicolons(';')"
55+
if ";" in first_line
56+
else "tabs"
57+
if "\t" in first_line
58+
else "Unknown delimiter"
59+
)
60+
return f"Wrong column delimiter {detected}. Save the data with commas ',' as the delimiter"
61+
62+
col_names = [c.strip().strip('"').strip("'") for c in first_line.split(",")]
63+
bad_cols = [c for c in col_names if re.search(r"[^A-Za-z0-9_ \-.]", c)]
64+
if bad_cols:
65+
return (
66+
f"Special characters found in column name(s): {bad_cols}."
67+
f"Column names may contain only letters, numbers and underscores."
68+
f"Please rename columns {bad_cols} before uploading data again."
69+
)
70+
return None
71+
72+
4673
@pn.cache
4774
def load_data(text_fname):
4875
if text_fname is None:
49-
text_fname = "tests/data/stress.csv"
50-
else:
51-
text_fname = io.BytesIO(text_fname)
76+
return pd.read_csv("tests/data/stress.csv")
77+
78+
raw_bytes = bytes(text_fname)
5279

53-
data = pd.read_csv(text_fname)
54-
return data
80+
# Run pre-validation
81+
error = _validate_csv_bytes(raw_bytes)
82+
if error:
83+
pn.state.notifications.error(error, duration=0)
84+
return None
85+
86+
# Try parsing
87+
try:
88+
text_fname = io.BytesIO(text_fname)
89+
return pd.read_csv(text_fname)
90+
except Exception as e:
91+
pn.state.notifications.error(f"Could not parse CSV {e}.", duration=0)
92+
return None
5593

5694

5795
@pn.cache
5896
def column_inputs(data, output):
97+
if data is None:
98+
return []
5999
inputs = list(data.columns)
60-
inputs.remove(output)
100+
if output in inputs:
101+
inputs.remove(output)
61102
return inputs
62103

63104

64105
@pn.cache
65106
def column_output(data):
107+
if data is None:
108+
return []
66109
return list(data.columns)
67110

68111

69112
@pn.cache
70113
def filtered_data(data, output_name):
114+
if data is None or not output_name:
115+
return pd.Series(dtype=float)
71116
try:
72117
return data[output_name]
73118
except KeyError:
119+
if isinstance(output_name, list):
120+
return data.iloc[:, [0]]
74121
return data.iloc[:, 0]
75122

76123

@@ -350,7 +397,7 @@ def csv_data(
350397

351398
interactive_column_output = pn.bind(column_output, interactive_file)
352399
# hack to make the default selection faster
353-
interactive_output_ = pn.bind(lambda x: x[0], interactive_column_output)
400+
interactive_output_ = pn.bind(lambda x: x[0] if x else None, interactive_column_output)
354401
selector_output = pn.widgets.Select(
355402
name="Output", value=interactive_output_, options=interactive_column_output
356403
)

src/simdec/visualization.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import functools
33
import itertools
4-
from typing import Literal
4+
from typing import Literal, Optional
55

66
import colorsys
77
import matplotlib as mpl
@@ -135,17 +135,25 @@ def palette(
135135
def visualization(
136136
*,
137137
bins: pd.DataFrame,
138+
bins2: Optional[pd.DataFrame] = None,
138139
palette: list[list[float]],
139140
n_bins: str | int = "auto",
140141
kind: Literal["histogram", "boxplot"] = "histogram",
141142
ax=None,
143+
output_name: str = "Output 1",
144+
output_name2: str = "Output 2",
145+
xlim: Optional[tuple[float, float]] = None,
146+
ylim: Optional[tuple[float, float]] = None,
147+
r_scatter: float = 1.0,
142148
) -> plt.Axes:
143149
"""Histogram plot of scenarios.
144150
145151
Parameters
146152
----------
147153
bins : DataFrame
148154
Multidimensional bins.
155+
bins2 : DataFrame
156+
Multidimensional bins for output 2
149157
palette : list of int of size (n, 4)
150158
List of colours corresponding to scenarios.
151159
n_bins : str or int
@@ -154,16 +162,85 @@ def visualization(
154162
Histogram or Box Plot.
155163
ax : Axes, optional
156164
Matplotlib axis.
165+
output_name : str, default "Output 1"
166+
Name of the primary output variable.
167+
output_name2 : str, default "Output 2"
168+
Name of the second output variable.
169+
xlim : tuple of float, optional
170+
Minimum and maximum values for the x-axis (Output 1).
171+
ylim : tuple of float, optional
172+
Minimum and maximum values for the y-axis (Output 2).
173+
r_scatter : float, default 1.0
174+
The portion of data points displayed on the scatter plot (0 to 1).
157175
158176
Returns
159177
-------
178+
axs : Axes
179+
Matplotlib axis for two-output graph.
160180
ax : Axes
161181
Matplotlib axis.
162182
163183
"""
164184
# needed to get the correct stacking order
165185
bins.columns = pd.RangeIndex(start=len(bins.columns), stop=0, step=-1)
166186

187+
if bins2 is not None:
188+
fig, axs = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(8, 8))
189+
axs[0, 1].axis("off")
190+
191+
sns.histplot(
192+
bins,
193+
multiple="stack",
194+
stat="probability",
195+
palette=palette,
196+
common_bins=True,
197+
common_norm=True,
198+
bins=n_bins,
199+
legend=False,
200+
ax=axs[0, 0],
201+
)
202+
axs[0, 0].set_xlim(xlim)
203+
axs[0, 0].set_box_aspect(1)
204+
axs[0, 0].axis("off")
205+
206+
data = pd.concat([pd.melt(bins), pd.melt(bins2)["value"]], axis=1)
207+
data.columns = ["c", "x", "y"]
208+
209+
if r_scatter < 1.0:
210+
data = data.sample(frac=r_scatter)
211+
212+
sns.scatterplot(
213+
data=data,
214+
x="x",
215+
y="y",
216+
hue="c",
217+
palette=palette,
218+
ax=axs[1, 0],
219+
legend=False,
220+
)
221+
axs[1, 0].set(xlabel=output_name, ylabel=output_name2)
222+
axs[1, 0].set_box_aspect(1)
223+
224+
sns.histplot(
225+
data,
226+
y="y",
227+
hue="c",
228+
multiple="stack",
229+
stat="probability",
230+
palette=palette,
231+
common_bins=True,
232+
common_norm=True,
233+
bins=40,
234+
legend=False,
235+
ax=axs[1, 1],
236+
)
237+
axs[1, 1].set_ylim(ylim)
238+
axs[1, 1].set_box_aspect(1)
239+
axs[1, 1].axis("off")
240+
241+
fig.subplots_adjust(wspace=-0.015, hspace=0)
242+
return axs[1, 0]
243+
167244
if kind == "histogram":
168245
ax = sns.histplot(
169246
bins,

tests/test_visualization.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
import pandas as pd
3+
import matplotlib.pyplot as plt
4+
import simdec as sd
5+
6+
7+
def test_visualization_single_output():
8+
bins = pd.DataFrame({"s1": [1, 2], "s2": [3, 4]})
9+
palette = [[1, 0, 0, 1], [0, 1, 0, 1]]
10+
11+
ax = sd.visualization(bins=bins, palette=palette, kind="histogram")
12+
assert isinstance(ax, plt.Axes)
13+
14+
ax_box = sd.visualization(bins=bins, palette=palette, kind="boxplot")
15+
assert isinstance(ax_box, plt.Axes)
16+
17+
18+
def test_visualization_two_outputs():
19+
bins = pd.DataFrame({"s1": [1, 2]})
20+
bins2 = pd.DataFrame({"s1": [5, 6]})
21+
palette = [[1, 0, 0, 1]]
22+
23+
ax = sd.visualization(bins=bins, bins2=bins2, palette=palette)
24+
25+
assert ax.get_xlabel() == "Output 1"
26+
assert len(ax.figure.axes) == 4
27+
28+
29+
def test_visualization_invalid_kind():
30+
bins = pd.DataFrame({"s1": [1]})
31+
with pytest.raises(ValueError, match="'kind' can only be 'histogram' or 'boxplot'"):
32+
sd.visualization(bins=bins, palette=[[1, 0, 0, 1]], kind="invalid")

0 commit comments

Comments
 (0)