Skip to content

Commit 59da170

Browse files
authored
Extend make_palette_from_data to support labels (#670)
1 parent 2329522 commit 59da170

2 files changed

Lines changed: 26 additions & 8 deletions

File tree

src/spatialdata_plot/pl/_palette.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ def _resolve_element(
284284
"""Extract categorical labels from a SpatialData element.
285285
286286
Labels come from a column on the element itself, or from a linked
287-
table (joined on the instance key to guarantee alignment).
287+
table (joined on the instance key to guarantee alignment). Raster
288+
(labels) elements have no per-pixel column, so the value always
289+
comes from the linked table.
288290
"""
289291
if element in sdata.shapes:
290292
gdf = sdata.shapes[element]
@@ -298,11 +300,13 @@ def _resolve_element(
298300
labels_series = ddf[[color]].compute()[color]
299301
else:
300302
labels_series = _get_labels_from_table(sdata, element, color, table_name)
303+
elif element in sdata.labels:
304+
labels_series = _get_labels_from_table(sdata, element, color, table_name)
301305
else:
302-
available = list(sdata.shapes.keys()) + list(sdata.points.keys())
306+
available = list(sdata.shapes.keys()) + list(sdata.points.keys()) + list(sdata.labels.keys())
303307
raise KeyError(
304-
f"Element '{element}' not found in sdata.shapes or sdata.points. "
305-
f"Available elements: {available}. Note: labels (raster) elements are not yet supported."
308+
f"Element '{element}' not found in sdata.shapes, sdata.points, or sdata.labels. "
309+
f"Available elements: {available}."
306310
)
307311

308312
is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype)
@@ -471,11 +475,12 @@ def make_palette_from_data(
471475
sdata
472476
A :class:`spatialdata.SpatialData` object.
473477
element
474-
Name of a shapes or points element in *sdata*.
478+
Name of a shapes, points, or labels element in *sdata*.
475479
color
476-
Column name containing categorical labels. The column is first
477-
looked up directly on the element (both for shapes and points);
478-
if not found there, it falls back to the linked AnnData table.
480+
Column name containing categorical labels. For shapes and points
481+
the column is first looked up directly on the element and falls
482+
back to the linked AnnData table. For labels the column is
483+
always read from the linked table.
479484
palette
480485
Source colours. Accepts the same values as
481486
:func:`make_palette` (*None*, a list, a named palette, or a

tests/pl/test_palette.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,15 @@ def test_shapes_with_table(self, shapes_sdata: SpatialData):
224224
assert isinstance(result, dict)
225225
assert set(result.keys()) == {"X", "Y", "Z"}
226226

227+
def test_labels_with_table(self, sdata_blobs: SpatialData):
228+
# Regression test for #662
229+
n = sdata_blobs["table"].n_obs
230+
sdata_blobs["table"].obs["cell_type"] = pd.Categorical(np.random.default_rng(0).choice(["X", "Y", "Z"], size=n))
231+
result = make_palette_from_data(sdata_blobs, "blobs_labels", "cell_type", seed=42)
232+
assert isinstance(result, dict)
233+
assert set(result.keys()) == {"X", "Y", "Z"}
234+
assert all(v.startswith("#") for v in result.values())
235+
227236

228237
# ---------------------------------------------------------------------------
229238
# Error cases
@@ -243,6 +252,10 @@ def test_missing_column(self, clustered_sdata: SpatialData):
243252
with pytest.raises(KeyError, match="not found"):
244253
make_palette_from_data(clustered_sdata, "cells", "nonexistent_col")
245254

255+
def test_missing_column_on_labels(self, sdata_blobs: SpatialData):
256+
with pytest.raises(KeyError, match="not found"):
257+
make_palette_from_data(sdata_blobs, "blobs_labels", "nonexistent_col")
258+
246259
def test_unknown_method(self, clustered_sdata: SpatialData):
247260
with pytest.raises(ValueError, match="Unknown method"):
248261
make_palette_from_data(clustered_sdata, "cells", "cell_type", method="invalid") # type: ignore[arg-type]

0 commit comments

Comments
 (0)