Skip to content

Commit d90dd4d

Browse files
committed
Add two-output visualization and robust CSV parsing
Added a new function for two-output visualization. Wrapped CSV parsing in the dashboard in a try/except block: if a wrong delimiter or an invalid character in a column name is detected, the dashboard falls back to loading stress.csv, which stops the buffering and prevents reactive cascade crashes. Closes #45. Closes #50.
1 parent 8e4beda commit d90dd4d

File tree

4 files changed

+163
-133
lines changed

4 files changed

+163
-133
lines changed

panel/simdec_app.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -41,55 +41,30 @@
4141
# save_layout=True,
4242
)
4343

44-
45-
def _validate_csv_bytes(raw_bytes):
46-
"""Pre-parse validation. Returns an error string or None."""
47-
try:
48-
first_line = raw_bytes.decode("utf-8").split("\n")[0].strip()
49-
except UnicodeDecodeError:
50-
return "File encoding error. Please use files in UTF-8."
51-
52-
if "," not in first_line:
53-
detected = (
54-
"Semicolons(';')"
55-
if ";" in first_line
56-
else "tabs"
57-
if "\t" in first_line
58-
else "Unknown delimiter"
59-
)
60-
return f"Wrong column delimiter {detected}. Save the data with commas ',' as the delimiter"
61-
62-
col_names = [c.strip().strip('"').strip("'") for c in first_line.split(",")]
63-
bad_cols = [c for c in col_names if re.search(r"[^A-Za-z0-9_ \-.]", c)]
64-
if bad_cols:
65-
return (
66-
f"Special characters found in column name(s): {bad_cols}."
67-
f"Column names may contain only letters, numbers and underscores."
68-
f"Please rename columns {bad_cols} before uploading data again."
69-
)
70-
return None
44+
# Matches any single character that is NOT allowed in a column name (allowed:
# letters, digits, underscore, space, hyphen, dot). The class must be negated
# because callers use ``.search()`` and treat a match as "bad column name";
# without the ``^`` every ordinary column name would be rejected.
VALID_CHARACTERS = re.compile(r"[^A-Za-z0-9_ \-.]")
45+
GENERIC_ERROR_MSG = (
46+
"Could not parse the CSV file. "
47+
"Please check that it uses commas ',' as the delimiter "
48+
"and that column names contain no special characters."
49+
)
7150

7251

7352
@pn.cache
7453
def load_data(text_fname):
7554
if text_fname is None:
7655
return pd.read_csv("tests/data/stress.csv")
77-
78-
raw_bytes = bytes(text_fname)
79-
80-
# Run pre-validation
81-
error = _validate_csv_bytes(raw_bytes)
82-
if error:
83-
pn.state.notifications.error(error, duration=0)
84-
return None
85-
86-
# Try parsing
8756
try:
88-
text_fname = io.BytesIO(text_fname)
89-
return pd.read_csv(text_fname)
90-
except Exception as e:
91-
pn.state.notifications.error(f"Could not parse CSV {e}.", duration=0)
92-
return None
57+
raw = bytes(text_fname)
58+
first_line = raw.decode("utf-8").split("\n")[0].strip()
59+
if "," not in first_line:
60+
raise ValueError("No comma delimiter")
61+
col_names = [c.strip().strip('"').strip("'") for c in first_line.split(",")]
62+
if any(VALID_CHARACTERS.search(c) for c in col_names):
63+
raise ValueError("Bad column names")
64+
return pd.read_csv(io.BytesIO(raw))
65+
except Exception:
66+
pn.state.notifications.error(GENERIC_ERROR_MSG, duration=0)
67+
return pd.read_csv("tests/data/stress.csv")
9368

9469

9570
@pn.cache

src/simdec/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"states_expansion",
99
"decomposition",
1010
"visualization",
11+
"two_output_visualization",
1112
"tableau",
1213
"palette",
1314
]

src/simdec/visualization.py

Lines changed: 101 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import copy
22
import functools
33
import itertools
4-
from typing import Literal, Optional
4+
from typing import Literal
55

66
import colorsys
77
import matplotlib as mpl
@@ -11,7 +11,7 @@
1111
import pandas as pd
1212
from pandas.io.formats.style import Styler
1313

14-
__all__ = ["visualization", "tableau", "palette"]
14+
__all__ = ["visualization", "two_output_visualization", "tableau", "palette"]
1515

1616

1717
SEQUENTIAL_PALETTES = [
@@ -135,25 +135,17 @@ def palette(
135135
def visualization(
136136
*,
137137
bins: pd.DataFrame,
138-
bins2: Optional[pd.DataFrame] = None,
139138
palette: list[list[float]],
140139
n_bins: str | int = "auto",
141140
kind: Literal["histogram", "boxplot"] = "histogram",
142141
ax=None,
143-
output_name: str = "Output 1",
144-
output_name2: str = "Output 2",
145-
xlim: Optional[tuple[float, float]] = None,
146-
ylim: Optional[tuple[float, float]] = None,
147-
r_scatter: float = 1.0,
148142
) -> plt.Axes:
149143
"""Histogram plot of scenarios.
150144
151145
Parameters
152146
----------
153147
bins : DataFrame
154148
Multidimensional bins.
155-
bins2 : DataFrame
156-
Multidimensional bins for output 2
157149
palette : list of int of size (n, 4)
158150
List of colours corresponding to scenarios.
159151
n_bins : str or int
@@ -162,85 +154,16 @@ def visualization(
162154
Histogram or Box Plot.
163155
ax : Axes, optional
164156
Matplotlib axis.
165-
output_name : str, default "Output 1"
166-
Name of the primary output variable.
167-
output_name2 : str, default "Output 2"
168-
Name of the second output variable.
169-
xlim : tuple of float, optional
170-
Minimum and maximum values for the x-axis (Output 1).
171-
ylim : tuple of float, optional
172-
Minimum and maximum values for the y-axis (Output 2).
173-
r_scatter : float, default 1.0
174-
The portion of data points displayed on the scatter plot (0 to 1).
175157
176158
Returns
177159
-------
178-
axs : Axes
179-
Matplotlib axis for two-output graph.
180160
ax : Axes
181161
Matplotlib axis.
182162
183163
"""
184164
# needed to get the correct stacking order
185165
bins.columns = pd.RangeIndex(start=len(bins.columns), stop=0, step=-1)
186166

187-
if bins2 is not None:
188-
fig, axs = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(8, 8))
189-
axs[0, 1].axis("off")
190-
191-
sns.histplot(
192-
bins,
193-
multiple="stack",
194-
stat="probability",
195-
palette=palette,
196-
common_bins=True,
197-
common_norm=True,
198-
bins=n_bins,
199-
legend=False,
200-
ax=axs[0, 0],
201-
)
202-
axs[0, 0].set_xlim(xlim)
203-
axs[0, 0].set_box_aspect(1)
204-
axs[0, 0].axis("off")
205-
206-
data = pd.concat([pd.melt(bins), pd.melt(bins2)["value"]], axis=1)
207-
data.columns = ["c", "x", "y"]
208-
209-
if r_scatter < 1.0:
210-
data = data.sample(frac=r_scatter)
211-
212-
sns.scatterplot(
213-
data=data,
214-
x="x",
215-
y="y",
216-
hue="c",
217-
palette=palette,
218-
ax=axs[1, 0],
219-
legend=False,
220-
)
221-
axs[1, 0].set(xlabel=output_name, ylabel=output_name2)
222-
axs[1, 0].set_box_aspect(1)
223-
224-
sns.histplot(
225-
data,
226-
y="y",
227-
hue="c",
228-
multiple="stack",
229-
stat="probability",
230-
palette=palette,
231-
common_bins=True,
232-
common_norm=True,
233-
bins=40,
234-
legend=False,
235-
ax=axs[1, 1],
236-
)
237-
axs[1, 1].set_ylim(ylim)
238-
axs[1, 1].set_box_aspect(1)
239-
axs[1, 1].axis("off")
240-
241-
fig.subplots_adjust(wspace=-0.015, hspace=0)
242-
return axs[1, 0]
243-
244167
if kind == "histogram":
245168
ax = sns.histplot(
246169
bins,
@@ -266,6 +189,105 @@ def visualization(
266189
return ax
267190

268191

192+
def two_output_visualization(
193+
*,
194+
bins: pd.DataFrame,
195+
bins2: pd.DataFrame,
196+
palette: list[list[float]],
197+
n_bins: str | int = "auto",
198+
output_name: str = "Output 1",
199+
output_name2: str = "Output 2",
200+
xlim: tuple[float, float] | None = None,
201+
ylim: tuple[float, float] | None = None,
202+
r_scatter: float = 1.0,
203+
) -> tuple[plt.Figure, np.ndarray]:
204+
"""Two-output visualization.
205+
Produces a 2x2 figure
206+
* top-left : stacked histogram for *output 1* (axes hidden)
207+
* bottom-left : scatter of output 1 vs output 2, coloured by scenario
208+
* bottom-right: rotated stacked histogram for *output 2* (axes hidden)
209+
* top-right : empty
210+
211+
Parameters
212+
----------
213+
bins : DataFrame
214+
Multidimensional bins for the primary output.
215+
bins2 : DataFrame
216+
Multidimensional bins for the secondary output.
217+
palette : list of int of size (n, 4)
218+
List of colours corresponding to scenarios.
219+
n_bins : str or int
220+
Number of bins for the histograms.
221+
output_name : str, default "Output 1"
222+
Axis label for the primary output.
223+
output_name2 : str, default "Output 2"
224+
Axis label for the secondary output.
225+
xlim : tuple of float, optional
226+
Limits for the primary output axis (scatter x / top histogram).
227+
ylim : tuple of float, optional
228+
Limits for the secondary output axis (scatter y / right histogram).
229+
r_scatter : float, default 1.0
230+
Fraction of data points shown in the scatter plot.
231+
232+
Returns
233+
-------
234+
fig : Figure
235+
axs : ndarray of shape (2, 2)
236+
237+
"""
238+
fig, axs = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(8, 8))
239+
240+
axs[0, 1].axis("off")
241+
242+
visualization(bins=bins.copy(), palette=palette, n_bins=n_bins, ax=axs[0, 0])
243+
if xlim is not None:
244+
axs[0, 0].set_xlim(xlim)
245+
axs[0, 0].set_box_aspect(1)
246+
axs[0, 0].axis("off")
247+
248+
data = pd.concat([pd.melt(bins), pd.melt(bins2)["value"]], axis=1)
249+
data.columns = ["c", "x", "y"]
250+
if r_scatter < 1.0:
251+
data = data.sample(frac=r_scatter)
252+
253+
sns.scatterplot(
254+
data=data,
255+
x="x",
256+
y="y",
257+
hue="c",
258+
palette=palette,
259+
ax=axs[1, 0],
260+
legend=False,
261+
)
262+
axs[1, 0].set(xlabel=output_name, ylabel=output_name2)
263+
if xlim is not None:
264+
axs[1, 0].set_xlim(xlim)
265+
if ylim is not None:
266+
axs[1, 0].set_ylim(ylim)
267+
axs[1, 0].set_box_aspect(1)
268+
269+
sns.histplot(
270+
data,
271+
y="y",
272+
hue="c",
273+
multiple="stack",
274+
stat="probability",
275+
palette=palette,
276+
common_bins=True,
277+
common_norm=True,
278+
bins=40,
279+
legend=False,
280+
ax=axs[1, 1],
281+
)
282+
if ylim is not None:
283+
axs[1, 1].set_ylim(ylim)
284+
axs[1, 1].set_box_aspect(1)
285+
axs[1, 1].axis("off")
286+
287+
fig.subplots_adjust(wspace=-0.015, hspace=0)
288+
return fig, axs
289+
290+
269291
def tableau(
270292
*,
271293
var_names: list[str],

tests/test_visualization.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,61 @@
44
import simdec as sd
55

66

7-
def test_visualization_single_output():
7+
@pytest.fixture(autouse=True)
8+
def close_plots():
9+
yield
10+
plt.close("all")
11+
12+
13+
def test_visualization_histogram():
814
bins = pd.DataFrame({"s1": [1, 2], "s2": [3, 4]})
915
palette = [[1, 0, 0, 1], [0, 1, 0, 1]]
10-
1116
ax = sd.visualization(bins=bins, palette=palette, kind="histogram")
1217
assert isinstance(ax, plt.Axes)
1318

14-
ax_box = sd.visualization(bins=bins, palette=palette, kind="boxplot")
15-
assert isinstance(ax_box, plt.Axes)
1619

20+
def test_visualization_boxplot():
21+
bins = pd.DataFrame({"s1": [1, 2], "s2": [3, 4]})
22+
palette = [[1, 0, 0, 1], [0, 1, 0, 1]]
23+
ax = sd.visualization(bins=bins, palette=palette, kind="boxplot")
24+
assert isinstance(ax, plt.Axes)
1725

18-
def test_visualization_two_outputs():
26+
27+
def test_visualization_invalid_kind():
28+
bins = pd.DataFrame({"s1": [1]})
29+
with pytest.raises(ValueError, match="'kind' can only be 'histogram' or 'boxplot'"):
30+
sd.visualization(bins=bins, palette=[[1, 0, 0, 1]], kind="invalid")
31+
32+
33+
def test_two_output_visualization_returns_correct_types():
1934
bins = pd.DataFrame({"s1": [1, 2]})
2035
bins2 = pd.DataFrame({"s1": [5, 6]})
2136
palette = [[1, 0, 0, 1]]
37+
fig, axs = sd.two_output_visualization(bins=bins, bins2=bins2, palette=palette)
38+
assert isinstance(fig, plt.Figure)
39+
assert axs.shape == (2, 2)
2240

23-
ax = sd.visualization(bins=bins, bins2=bins2, palette=palette)
2441

25-
assert ax.get_xlabel() == "Output 1"
26-
assert len(ax.figure.axes) == 4
42+
def test_two_output_visualization_axis_labels():
43+
bins = pd.DataFrame({"s1": [1, 2]})
44+
bins2 = pd.DataFrame({"s1": [5, 6]})
45+
palette = [[1, 0, 0, 1]]
46+
_, axs = sd.two_output_visualization(
47+
bins=bins,
48+
bins2=bins2,
49+
palette=palette,
50+
output_name="Stress",
51+
output_name2="Displacement",
52+
)
53+
assert axs[1, 0].get_xlabel() == "Stress"
54+
assert axs[1, 0].get_ylabel() == "Displacement"
2755

2856

29-
def test_visualization_invalid_kind():
30-
bins = pd.DataFrame({"s1": [1]})
31-
with pytest.raises(ValueError, match="'kind' can only be 'histogram' or 'boxplot'"):
32-
sd.visualization(bins=bins, palette=[[1, 0, 0, 1]], kind="invalid")
57+
def test_two_output_visualization_r_scatter():
58+
bins = pd.DataFrame({"s1": list(range(100))})
59+
bins2 = pd.DataFrame({"s1": list(range(100))})
60+
palette = [[1, 0, 0, 1]]
61+
fig, axs = sd.two_output_visualization(
62+
bins=bins, bins2=bins2, palette=palette, r_scatter=0.5
63+
)
64+
assert isinstance(fig, plt.Figure)

0 commit comments

Comments
 (0)