From 582899251f73ef84bd08c2aa7fb3d2bf79422ba6 Mon Sep 17 00:00:00 2001 From: anon Date: Tue, 19 May 2026 19:53:00 +0200 Subject: [PATCH 1/2] Extend make_palette_from_data to support labels (#662) Labels (raster) elements have no per-pixel column, so the categorical column always lives on the linked AnnData table. Dispatch through _get_labels_from_table for the labels case; broaden the KeyError to list labels in the available-elements set. --- src/spatialdata_plot/pl/_palette.py | 21 ++++++++------ tests/pl/test_palette.py | 44 ++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/src/spatialdata_plot/pl/_palette.py b/src/spatialdata_plot/pl/_palette.py index 4159de3a..83d1b01b 100644 --- a/src/spatialdata_plot/pl/_palette.py +++ b/src/spatialdata_plot/pl/_palette.py @@ -284,7 +284,9 @@ def _resolve_element( """Extract categorical labels from a SpatialData element. Labels come from a column on the element itself, or from a linked - table (joined on the instance key to guarantee alignment). + table (joined on the instance key to guarantee alignment). Raster + (labels) elements have no per-pixel column, so the value always + comes from the linked table. """ if element in sdata.shapes: gdf = sdata.shapes[element] @@ -298,11 +300,13 @@ def _resolve_element( labels_series = ddf[[color]].compute()[color] else: labels_series = _get_labels_from_table(sdata, element, color, table_name) + elif element in sdata.labels: + labels_series = _get_labels_from_table(sdata, element, color, table_name) else: - available = list(sdata.shapes.keys()) + list(sdata.points.keys()) + available = list(sdata.shapes.keys()) + list(sdata.points.keys()) + list(sdata.labels.keys()) raise KeyError( - f"Element '{element}' not found in sdata.shapes or sdata.points. " - f"Available elements: {available}. Note: labels (raster) elements are not yet supported." + f"Element '{element}' not found in sdata.shapes, sdata.points, or sdata.labels. " + f"Available elements: {available}." ) is_categorical = isinstance(getattr(labels_series, "dtype", None), pd.CategoricalDtype) @@ -471,11 +475,12 @@ def make_palette_from_data( sdata A :class:`spatialdata.SpatialData` object. element - Name of a shapes or points element in *sdata*. + Name of a shapes, points, or labels element in *sdata*. color - Column name containing categorical labels. The column is first - looked up directly on the element (both for shapes and points); - if not found there, it falls back to the linked AnnData table. + Column name containing categorical labels. For shapes and points + the column is first looked up directly on the element and falls + back to the linked AnnData table. For labels the column is + always read from the linked table. palette Source colours. Accepts the same values as :func:`make_palette` (*None*, a list, a named palette, or a diff --git a/tests/pl/test_palette.py b/tests/pl/test_palette.py index 311209d0..82b1d71b 100644 --- a/tests/pl/test_palette.py +++ b/tests/pl/test_palette.py @@ -9,7 +9,7 @@ import scanpy as sc from matplotlib.colors import to_hex, to_rgb from spatialdata import SpatialData -from spatialdata.models import PointsModel, ShapesModel, TableModel +from spatialdata.models import Labels2DModel, PointsModel, ShapesModel, TableModel import spatialdata_plot # noqa: F401 — registers accessor from spatialdata_plot.pl._palette import ( @@ -67,6 +67,32 @@ def _build_shapes_sdata(seed: int = 0) -> SpatialData: return SpatialData(shapes={"my_shapes": ShapesModel.parse(gdf)}, tables={"table": adata}) +def _build_labels_sdata(seed: int = 0) -> SpatialData: + """SpatialData with a labels element + linked table containing categorical labels.""" + from anndata import AnnData + + rng = np.random.default_rng(seed) + n = 10 + arr = np.zeros((20, 20), dtype=np.int32) + for i in range(n): + y, x = divmod(i, 5) + arr[y * 4 : y * 4 + 3, x * 4 : x * 4 + 3] = i + 1 + + adata = AnnData( + np.zeros((n, 1)), + obs=pd.DataFrame( + { + "cell_type": pd.Categorical(rng.choice(["X", "Y", "Z"], size=n)), + "instance_id": np.arange(1, n + 1), + "region": ["my_labels"] * n, + }, + index=pd.RangeIndex(n).astype(str), + ), + ) + adata = TableModel.parse(adata=adata, region="my_labels", region_key="region", instance_key="instance_id") + return SpatialData(labels={"my_labels": Labels2DModel.parse(arr)}, tables={"table": adata}) + + @pytest.fixture(scope="module") def clustered_sdata() -> SpatialData: return _build_clustered_points_sdata() @@ -77,6 +103,11 @@ def shapes_sdata() -> SpatialData: return _build_shapes_sdata() +@pytest.fixture(scope="module") +def labels_sdata() -> SpatialData: + return _build_labels_sdata() + + # --------------------------------------------------------------------------- # Unit tests: internals # --------------------------------------------------------------------------- @@ -224,6 +255,13 @@ def test_shapes_with_table(self, shapes_sdata: SpatialData): assert isinstance(result, dict) assert set(result.keys()) == {"X", "Y", "Z"} + def test_labels_with_table(self, labels_sdata: SpatialData): + # Regression test for #662 + result = make_palette_from_data(labels_sdata, "my_labels", "cell_type", seed=42) + assert isinstance(result, dict) + assert set(result.keys()) == {"X", "Y", "Z"} + assert all(v.startswith("#") for v in result.values()) + # --------------------------------------------------------------------------- # Error cases @@ -243,6 +281,10 @@ def test_missing_column(self, clustered_sdata: SpatialData): with pytest.raises(KeyError, match="not found"): make_palette_from_data(clustered_sdata, "cells", "nonexistent_col") + def test_missing_column_on_labels(self, labels_sdata: SpatialData): + with pytest.raises(KeyError, match="not found"): + make_palette_from_data(labels_sdata, "my_labels", "nonexistent_col") + def test_unknown_method(self, clustered_sdata: SpatialData): with pytest.raises(ValueError, match="Unknown method"): make_palette_from_data(clustered_sdata, "cells", "cell_type", method="invalid") # type: ignore[arg-type] From 659640e7b9b8c2d204fe37dd9e06c4d2462287ce Mon Sep 17 00:00:00 2001 From: anon Date: Tue, 19 May 2026 19:59:07 +0200 Subject: [PATCH 2/2] Reuse sdata_blobs fixture for labels tests Drop hand-rolled _build_labels_sdata / labels_sdata fixture. sdata_blobs already provides blobs_labels with a linked table; add the categorical cell_type column on the table directly, matching the existing pattern in test_render_labels.py. --- tests/pl/test_palette.py | 43 +++++++--------------------------------- 1 file changed, 7 insertions(+), 36 deletions(-) diff --git a/tests/pl/test_palette.py b/tests/pl/test_palette.py index 82b1d71b..5177ac75 100644 --- a/tests/pl/test_palette.py +++ b/tests/pl/test_palette.py @@ -9,7 +9,7 @@ import scanpy as sc from matplotlib.colors import to_hex, to_rgb from spatialdata import SpatialData -from spatialdata.models import Labels2DModel, PointsModel, ShapesModel, TableModel +from spatialdata.models import PointsModel, ShapesModel, TableModel import spatialdata_plot # noqa: F401 — registers accessor from spatialdata_plot.pl._palette import ( @@ -67,32 +67,6 @@ def _build_shapes_sdata(seed: int = 0) -> SpatialData: return SpatialData(shapes={"my_shapes": ShapesModel.parse(gdf)}, tables={"table": adata}) -def _build_labels_sdata(seed: int = 0) -> SpatialData: - """SpatialData with a labels element + linked table containing categorical labels.""" - from anndata import AnnData - - rng = np.random.default_rng(seed) - n = 10 - arr = np.zeros((20, 20), dtype=np.int32) - for i in range(n): - y, x = divmod(i, 5) - arr[y * 4 : y * 4 + 3, x * 4 : x * 4 + 3] = i + 1 - - adata = AnnData( - np.zeros((n, 1)), - obs=pd.DataFrame( - { - "cell_type": pd.Categorical(rng.choice(["X", "Y", "Z"], size=n)), - "instance_id": np.arange(1, n + 1), - "region": ["my_labels"] * n, - }, - index=pd.RangeIndex(n).astype(str), - ), - ) - adata = TableModel.parse(adata=adata, region="my_labels", region_key="region", instance_key="instance_id") - return SpatialData(labels={"my_labels": Labels2DModel.parse(arr)}, tables={"table": adata}) - - @pytest.fixture(scope="module") def clustered_sdata() -> SpatialData: return _build_clustered_points_sdata() @@ -103,11 +77,6 @@ def shapes_sdata() -> SpatialData: return _build_shapes_sdata() -@pytest.fixture(scope="module") -def labels_sdata() -> SpatialData: - return _build_labels_sdata() - - # --------------------------------------------------------------------------- # Unit tests: internals # --------------------------------------------------------------------------- @@ -255,9 +224,11 @@ def test_shapes_with_table(self, shapes_sdata: SpatialData): assert isinstance(result, dict) assert set(result.keys()) == {"X", "Y", "Z"} - def test_labels_with_table(self, labels_sdata: SpatialData): + def test_labels_with_table(self, sdata_blobs: SpatialData): # Regression test for #662 - result = make_palette_from_data(labels_sdata, "my_labels", "cell_type", seed=42) + n = sdata_blobs["table"].n_obs + sdata_blobs["table"].obs["cell_type"] = pd.Categorical(np.random.default_rng(0).choice(["X", "Y", "Z"], size=n)) + result = make_palette_from_data(sdata_blobs, "blobs_labels", "cell_type", seed=42) assert isinstance(result, dict) assert set(result.keys()) == {"X", "Y", "Z"} assert all(v.startswith("#") for v in result.values()) @@ -281,9 +252,9 @@ def test_missing_column(self, clustered_sdata: SpatialData): with pytest.raises(KeyError, match="not found"): make_palette_from_data(clustered_sdata, "cells", "nonexistent_col") - def test_missing_column_on_labels(self, labels_sdata: SpatialData): + def test_missing_column_on_labels(self, sdata_blobs: SpatialData): with pytest.raises(KeyError, match="not found"): - make_palette_from_data(labels_sdata, "my_labels", "nonexistent_col") + make_palette_from_data(sdata_blobs, "blobs_labels", "nonexistent_col") def test_unknown_method(self, clustered_sdata: SpatialData): with pytest.raises(ValueError, match="Unknown method"):