diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 3f276beb..1a72a5cb 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -772,7 +772,8 @@ def _render_points( ): points = points[coords].compute() else: - coords += [col_for_color] + if col_for_color not in coords: + coords.append(col_for_color) points = points[coords].compute() added_color_from_table = False diff --git a/tests/pl/test_render_points.py b/tests/pl/test_render_points.py index 497fddd5..67dd71c9 100644 --- a/tests/pl/test_render_points.py +++ b/tests/pl/test_render_points.py @@ -1040,6 +1040,18 @@ def test_render_points_color_by_z_with_extra_columns(): plt.close(fig) +@pytest.mark.parametrize("color", ["x", "y"]) +def test_render_points_color_by_coord_axis(color): + # regression test for #613 + pts = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]})) + sdata = SpatialData(points={"p": pts}) + fig, ax = plt.subplots() + try: + sdata.pl.render_points("p", color=color).pl.show(ax=ax) + finally: + plt.close(fig) + + def test_render_points_disjoint_instance_ids_clear_error(): # regression test for #603: disjoint instance_id values must raise a clear ValueError points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))