Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions src/spatialdata_plot/pl/_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,9 @@ def _resolve_element(
"""Extract categorical labels from a SpatialData element.

Labels come from a column on the element itself, or from a linked
table (joined on the instance key to guarantee alignment).
table (joined on the instance key to guarantee alignment). Raster
(labels) elements have no per-pixel column, so the value always
comes from the linked table.
"""
if element in sdata.shapes:
gdf = sdata.shapes[element]
Expand All @@ -298,11 +300,13 @@ def _resolve_element(
labels_series = ddf[[color]].compute()[color]
else:
labels_series = _get_labels_from_table(sdata, element, color, table_name)
elif element in sdata.labels:
labels_series = _get_labels_from_table(sdata, element, color, table_name)
else:
available = list(sdata.shapes.keys()) + list(sdata.points.keys())
available = list(sdata.shapes.keys()) + list(sdata.points.keys()) + list(sdata.labels.keys())
raise KeyError(
f"Element '{element}' not found in sdata.shapes or sdata.points. "
f"Available elements: {available}. Note: labels (raster) elements are not yet supported."
f"Element '{element}' not found in sdata.shapes, sdata.points, or sdata.labels. "
f"Available elements: {available}."
)

is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype)
Expand Down Expand Up @@ -471,11 +475,12 @@ def make_palette_from_data(
sdata
A :class:`spatialdata.SpatialData` object.
element
Name of a shapes or points element in *sdata*.
Name of a shapes, points, or labels element in *sdata*.
color
Column name containing categorical labels. The column is first
looked up directly on the element (both for shapes and points);
if not found there, it falls back to the linked AnnData table.
Column name containing categorical labels. For shapes and points
the column is first looked up directly on the element and falls
back to the linked AnnData table. For labels the column is
always read from the linked table.
palette
Source colours. Accepts the same values as
:func:`make_palette` (*None*, a list, a named palette, or a
Expand Down
13 changes: 13 additions & 0 deletions tests/pl/test_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,15 @@ def test_shapes_with_table(self, shapes_sdata: SpatialData):
assert isinstance(result, dict)
assert set(result.keys()) == {"X", "Y", "Z"}

def test_labels_with_table(self, sdata_blobs: SpatialData):
# Regression test for #662
n = sdata_blobs["table"].n_obs
sdata_blobs["table"].obs["cell_type"] = pd.Categorical(np.random.default_rng(0).choice(["X", "Y", "Z"], size=n))
result = make_palette_from_data(sdata_blobs, "blobs_labels", "cell_type", seed=42)
assert isinstance(result, dict)
assert set(result.keys()) == {"X", "Y", "Z"}
assert all(v.startswith("#") for v in result.values())


# ---------------------------------------------------------------------------
# Error cases
Expand All @@ -243,6 +252,10 @@ def test_missing_column(self, clustered_sdata: SpatialData):
with pytest.raises(KeyError, match="not found"):
make_palette_from_data(clustered_sdata, "cells", "nonexistent_col")

def test_missing_column_on_labels(self, sdata_blobs: SpatialData):
with pytest.raises(KeyError, match="not found"):
make_palette_from_data(sdata_blobs, "blobs_labels", "nonexistent_col")

def test_unknown_method(self, clustered_sdata: SpatialData):
with pytest.raises(ValueError, match="Unknown method"):
make_palette_from_data(clustered_sdata, "cells", "cell_type", method="invalid") # type: ignore[arg-type]
Expand Down
Loading