Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,8 @@ def _render_points(
):
points = points[coords].compute()
else:
coords += [col_for_color]
if col_for_color not in coords:
coords.append(col_for_color)
points = points[coords].compute()

added_color_from_table = False
Expand Down
12 changes: 12 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,18 @@ def test_render_points_color_by_z_with_extra_columns():
plt.close(fig)


@pytest.mark.parametrize("color", ["x", "y"])
def test_render_points_color_by_coord_axis(color):
# regression test for #613
pts = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))
sdata = SpatialData(points={"p": pts})
fig, ax = plt.subplots()
try:
sdata.pl.render_points("p", color=color).pl.show(ax=ax)
finally:
plt.close(fig)


def test_render_points_disjoint_instance_ids_clear_error():
# regression test for #603: disjoint instance_id values must raise a clear ValueError
points = PointsModel.parse(pd.DataFrame({"x": [1.0, 2.0, 3.0], "y": [1.0, 2.0, 3.0]}))
Expand Down
Loading