diff --git a/src/spatialdata_plot/pl/_palette.py b/src/spatialdata_plot/pl/_palette.py index 4159de3a..83d1b01b 100644 --- a/src/spatialdata_plot/pl/_palette.py +++ b/src/spatialdata_plot/pl/_palette.py @@ -284,7 +284,9 @@ def _resolve_element( """Extract categorical labels from a SpatialData element. Labels come from a column on the element itself, or from a linked - table (joined on the instance key to guarantee alignment). + table (joined on the instance key to guarantee alignment). Raster + (labels) elements have no per-pixel column, so the value always + comes from the linked table. """ if element in sdata.shapes: gdf = sdata.shapes[element] @@ -298,11 +300,13 @@ def _resolve_element( labels_series = ddf[[color]].compute()[color] else: labels_series = _get_labels_from_table(sdata, element, color, table_name) + elif element in sdata.labels: + labels_series = _get_labels_from_table(sdata, element, color, table_name) else: - available = list(sdata.shapes.keys()) + list(sdata.points.keys()) + available = list(sdata.shapes.keys()) + list(sdata.points.keys()) + list(sdata.labels.keys()) raise KeyError( - f"Element '{element}' not found in sdata.shapes or sdata.points. " - f"Available elements: {available}. Note: labels (raster) elements are not yet supported." + f"Element '{element}' not found in sdata.shapes, sdata.points, or sdata.labels. " + f"Available elements: {available}." ) is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype) @@ -471,11 +475,12 @@ def make_palette_from_data( sdata A :class:`spatialdata.SpatialData` object. element - Name of a shapes or points element in *sdata*. + Name of a shapes, points, or labels element in *sdata*. color - Column name containing categorical labels. The column is first - looked up directly on the element (both for shapes and points); - if not found there, it falls back to the linked AnnData table. + Column name containing categorical labels. For shapes and points + the column is first looked up directly on the element and falls + back to the linked AnnData table. For labels the column is + always read from the linked table. palette Source colours. Accepts the same values as :func:`make_palette` (*None*, a list, a named palette, or a diff --git a/tests/pl/test_palette.py b/tests/pl/test_palette.py index 311209d0..5177ac75 100644 --- a/tests/pl/test_palette.py +++ b/tests/pl/test_palette.py @@ -224,6 +224,15 @@ def test_shapes_with_table(self, shapes_sdata: SpatialData): assert isinstance(result, dict) assert set(result.keys()) == {"X", "Y", "Z"} + def test_labels_with_table(self, sdata_blobs: SpatialData): + # Regression test for #662 + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["cell_type"] = pd.Categorical(np.random.default_rng(0).choice(["X", "Y", "Z"], size=n)) + result = make_palette_from_data(sdata_blobs, "blobs_labels", "cell_type", seed=42) + assert isinstance(result, dict) + assert set(result.keys()) == {"X", "Y", "Z"} + assert all(v.startswith("#") for v in result.values()) + # --------------------------------------------------------------------------- # Error cases @@ -243,6 +252,10 @@ def test_missing_column(self, clustered_sdata: SpatialData): with pytest.raises(KeyError, match="not found"): make_palette_from_data(clustered_sdata, "cells", "nonexistent_col") + def test_missing_column_on_labels(self, sdata_blobs: SpatialData): + with pytest.raises(KeyError, match="not found"): + make_palette_from_data(sdata_blobs, "blobs_labels", "nonexistent_col") + def test_unknown_method(self, clustered_sdata: SpatialData): with pytest.raises(ValueError, match="Unknown method"): make_palette_from_data(clustered_sdata, "cells", "cell_type", method="invalid") # type: ignore[arg-type]