Skip to content

Commit ce4aeae

Browse files
committed
color bar logic
1 parent 8d7aa1c commit ce4aeae

File tree

5 files changed

+563
-54
lines changed

5 files changed

+563
-54
lines changed

src/spatialdata_plot/pl/basic.py

Lines changed: 162 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import sys
5+
import warnings
46
from collections import OrderedDict
57
from copy import deepcopy
68
from pathlib import Path
7-
from typing import Any, Literal
9+
from typing import Any, Literal, cast
810

911
import matplotlib.pyplot as plt
1012
import numpy as np
@@ -17,6 +19,7 @@
1719
from matplotlib.axes import Axes
1820
from matplotlib.colors import Colormap, Normalize
1921
from matplotlib.figure import Figure
22+
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
2023
from spatialdata import get_extent
2124
from spatialdata._utils import _deprecation_alias
2225
from xarray import DataArray, DataTree
@@ -28,9 +31,11 @@
2831
_render_labels,
2932
_render_points,
3033
_render_shapes,
34+
_split_colorbar_params,
3135
)
3236
from spatialdata_plot.pl.render_params import (
3337
CmapParams,
38+
ColorbarSpec,
3439
ImageRenderParams,
3540
LabelsRenderParams,
3641
LegendParams,
@@ -172,6 +177,8 @@ def render_shapes(
172177
table_name: str | None = None,
173178
table_layer: str | None = None,
174179
shape: Literal["circle", "hex", "visium_hex", "square"] | None = None,
180+
colorbar: bool | str | None = "auto",
181+
colorbar_params: dict[str, object] | None = None,
175182
**kwargs: Any,
176183
) -> sd.SpatialData:
177184
"""
@@ -237,6 +244,11 @@ def render_shapes(
237244
method : str | None, optional
238245
Whether to use 'matplotlib' and 'datashader'. When None, the method is
239246
chosen based on the size of the data.
247+
colorbar :
248+
Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
249+
colorbar_params :
250+
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
251+
and ``label``.
240252
table_name: str | None
241253
Name of the table containing the color(s) columns. If one name is given than the table is used for each
242254
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
@@ -292,6 +304,8 @@ def render_shapes(
292304
shape=shape,
293305
method=method,
294306
ds_reduction=kwargs.get("datashader_reduction"),
307+
colorbar=colorbar,
308+
colorbar_params=colorbar_params,
295309
)
296310

297311
sdata = self._copy()
@@ -326,6 +340,8 @@ def render_shapes(
326340
zorder=n_steps,
327341
method=param_values["method"],
328342
ds_reduction=param_values["ds_reduction"],
343+
colorbar=param_values["colorbar"],
344+
colorbar_params=param_values["colorbar_params"],
329345
)
330346
n_steps += 1
331347

@@ -347,6 +363,8 @@ def render_points(
347363
method: str | None = None,
348364
table_name: str | None = None,
349365
table_layer: str | None = None,
366+
colorbar: bool | str | None = "auto",
367+
colorbar_params: dict[str, object] | None = None,
350368
**kwargs: Any,
351369
) -> sd.SpatialData:
352370
"""
@@ -396,6 +414,11 @@ def render_points(
396414
method : str | None, optional
397415
Whether to use 'matplotlib' and 'datashader'. When None, the method is
398416
chosen based on the size of the data.
417+
colorbar :
418+
Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
419+
colorbar_params :
420+
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
421+
and ``label``.
399422
table_name: str | None
400423
Name of the table containing the color(s) columns. If one name is given than the table is used for each
401424
spatial element to be plotted if the table annotates it. If you want to use different tables for particular
@@ -434,6 +457,8 @@ def render_points(
434457
table_name=table_name,
435458
table_layer=table_layer,
436459
ds_reduction=kwargs.get("datashader_reduction"),
460+
colorbar=colorbar,
461+
colorbar_params=colorbar_params,
437462
)
438463

439464
if method is not None:
@@ -467,6 +492,8 @@ def render_points(
467492
zorder=n_steps,
468493
method=method,
469494
ds_reduction=param_values["ds_reduction"],
495+
colorbar=param_values["colorbar"],
496+
colorbar_params=param_values["colorbar_params"],
470497
)
471498
n_steps += 1
472499

@@ -484,6 +511,8 @@ def render_images(
484511
palette: list[str] | str | None = None,
485512
alpha: float | int = 1.0,
486513
scale: str | None = None,
514+
colorbar: bool | str | None = "auto",
515+
colorbar_params: dict[str, object] | None = None,
487516
**kwargs: Any,
488517
) -> sd.SpatialData:
489518
"""
@@ -526,6 +555,11 @@ def render_images(
526555
3) "full": Renders the full image without rasterization. In the case of
527556
multiscale images, the highest resolution scale is selected. Note that
528557
this may result in long computing times for large images.
558+
colorbar :
559+
Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
560+
colorbar_params :
561+
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
562+
and ``label``.
529563
kwargs
530564
Additional arguments to be passed to cmap, norm, and other rendering functions.
531565
@@ -547,6 +581,8 @@ def render_images(
547581
cmap=cmap,
548582
norm=norm,
549583
scale=scale,
584+
colorbar=colorbar,
585+
colorbar_params=colorbar_params,
550586
)
551587

552588
sdata = self._copy()
@@ -580,6 +616,8 @@ def render_images(
580616
alpha=param_values["alpha"],
581617
scale=param_values["scale"],
582618
zorder=n_steps,
619+
colorbar=param_values["colorbar"],
620+
colorbar_params=param_values["colorbar_params"],
583621
)
584622
n_steps += 1
585623

@@ -600,6 +638,8 @@ def render_labels(
600638
outline_alpha: float | int = 0.0,
601639
fill_alpha: float | int = 0.4,
602640
scale: str | None = None,
641+
colorbar: bool | str | None = "auto",
642+
colorbar_params: dict[str, object] | None = None,
603643
table_name: str | None = None,
604644
table_layer: str | None = None,
605645
**kwargs: Any,
@@ -653,6 +693,11 @@ def render_labels(
653693
(exception: a dpi is specified in `show()`. Then the image is rasterized to fit the canvas and dpi).
654694
3) "full": render the full image without rasterization. In the case of a multiscale image, the scale
655695
with the highest resolution is selected. This can lead to long computing times for large images!
696+
colorbar :
697+
Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
698+
colorbar_params :
699+
Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
700+
and ``label``.
656701
table_name: str | None
657702
Name of the table containing the color columns.
658703
table_layer: str | None
@@ -681,6 +726,8 @@ def render_labels(
681726
outline_alpha=outline_alpha,
682727
palette=palette,
683728
scale=scale,
729+
colorbar=colorbar,
730+
colorbar_params=colorbar_params,
684731
table_name=table_name,
685732
table_layer=table_layer,
686733
)
@@ -709,6 +756,8 @@ def render_labels(
709756
table_name=param_values["table_name"],
710757
table_layer=param_values["table_layer"],
711758
zorder=n_steps,
759+
colorbar=param_values["colorbar"],
760+
colorbar_params=param_values["colorbar_params"],
712761
)
713762
n_steps += 1
714763
return sdata
@@ -722,7 +771,8 @@ def show(
722771
legend_loc: str | None = "right margin",
723772
legend_fontoutline: int | None = None,
724773
na_in_legend: bool = True,
725-
colorbar: bool = True,
774+
colorbar: bool | None = None,
775+
colorbar_params: dict[str, object] | None = None,
726776
wspace: float | None = None,
727777
hspace: float = 0.25,
728778
ncols: int = 4,
@@ -761,7 +811,11 @@ def show(
761811
return_ax :
762812
Whether to return the axes object created. False by default.
763813
colorbar :
764-
Whether to plot the colorbar. True by default.
814+
DEPRECATED: use per-layer ``colorbar``/``colorbar_params`` instead. If provided, it will be honored
815+
but will emit a DeprecationWarning.
816+
colorbar_params :
817+
Global overrides passed to colorbars for all axes. Accepts the same keys as per-layer ``colorbar_params``
818+
(e.g., ``loc``, ``width``, ``pad``, ``label``).
765819
title :
766820
The title of the plot. If not provided the plot will have the name of the coordinate system as title.
767821
@@ -786,6 +840,7 @@ def show(
786840
legend_fontoutline,
787841
na_in_legend,
788842
colorbar,
843+
colorbar_params,
789844
wspace,
790845
hspace,
791846
ncols,
@@ -888,15 +943,89 @@ def show(
888943
ncols=ncols,
889944
frameon=frameon,
890945
)
946+
# colorbar is deprecated: warn when explicitly provided
947+
legend_colorbar = True if colorbar is None else colorbar
948+
if colorbar is not None:
949+
warnings.warn(
950+
"Parameter 'colorbar' in `.show()` is deprecated. "
951+
"Please control colorbars via the per-layer `colorbar` and `colorbar_params` arguments.",
952+
DeprecationWarning,
953+
stacklevel=2,
954+
)
955+
891956
legend_params = LegendParams(
892957
legend_fontsize=legend_fontsize,
893958
legend_fontweight=legend_fontweight,
894959
legend_loc=legend_loc,
895960
legend_fontoutline=legend_fontoutline,
896961
na_in_legend=na_in_legend,
897-
colorbar=colorbar,
962+
colorbar=legend_colorbar,
898963
)
899964

965+
def _draw_colorbar(spec: ColorbarSpec, location_offsets: dict[str, float]) -> None:
966+
base_layout = {"location": "right", "fraction": 0.08, "pad": 0.01}
967+
layer_layout, layer_kwargs, layer_label_override = _split_colorbar_params(spec.params)
968+
global_layout, global_kwargs, global_label_override = _split_colorbar_params(colorbar_params)
969+
layout = {**base_layout, **layer_layout, **global_layout}
970+
cbar_kwargs = {**layer_kwargs, **global_kwargs}
971+
972+
location = cast(str, layout.get("location", base_layout["location"]))
973+
default_orientation = "vertical" if location in {"right", "left"} else "horizontal"
974+
cbar_kwargs.setdefault("orientation", default_orientation)
975+
976+
fraction = float(cast(float | int, layout.get("fraction", base_layout["fraction"])))
977+
pad = float(cast(float | int, layout.get("pad", base_layout["pad"])))
978+
offset = location_offsets.get(location, 0.0)
979+
pad = pad + offset
980+
# update offset for the next bar on the same side (space consumed = pad + fraction)
981+
location_offsets[location] = offset + pad + fraction
982+
983+
if location == "right":
984+
bbox = [1 + pad, 0, fraction, 1]
985+
elif location == "left":
986+
bbox = [-(pad + fraction), 0, fraction, 1]
987+
elif location == "top":
988+
bbox = [0, 1 + pad, 1, fraction]
989+
elif location == "bottom":
990+
bbox = [0, -(pad + fraction), 1, fraction]
991+
else:
992+
bbox = [1 + pad, 0, fraction, 1]
993+
994+
cax = inset_axes(
995+
spec.ax,
996+
width="100%",
997+
height="100%",
998+
loc="center",
999+
bbox_to_anchor=bbox,
1000+
bbox_transform=spec.ax.transAxes,
1001+
borderpad=0.0,
1002+
)
1003+
1004+
cb = fig_params.fig.colorbar(spec.mappable, cax=cax, **cbar_kwargs)
1005+
if location == "left":
1006+
cb.ax.yaxis.set_ticks_position("left")
1007+
cb.ax.yaxis.set_label_position("left")
1008+
cb.ax.tick_params(labelleft=True, labelright=False)
1009+
elif location == "top":
1010+
cb.ax.xaxis.set_ticks_position("top")
1011+
cb.ax.xaxis.set_label_position("top")
1012+
cb.ax.tick_params(labeltop=True, labelbottom=False)
1013+
elif location == "right":
1014+
cb.ax.yaxis.set_ticks_position("right")
1015+
cb.ax.yaxis.set_label_position("right")
1016+
cb.ax.tick_params(labelright=True, labelleft=False)
1017+
elif location == "bottom":
1018+
cb.ax.xaxis.set_ticks_position("bottom")
1019+
cb.ax.xaxis.set_label_position("bottom")
1020+
cb.ax.tick_params(labelbottom=True, labeltop=False)
1021+
1022+
final_label = global_label_override or layer_label_override or spec.label
1023+
if final_label:
1024+
cb.set_label(final_label)
1025+
if spec.alpha is not None:
1026+
with contextlib.suppress(Exception):
1027+
cb.solids.set_alpha(spec.alpha)
1028+
9001029
cs_contents = _get_cs_contents(sdata)
9011030

9021031
# go through tree
@@ -908,6 +1037,8 @@ def show(
9081037
)
9091038
ax = fig_params.ax if fig_params.axs is None else fig_params.axs[i]
9101039
assert isinstance(ax, Axes)
1040+
axis_colorbar_requests: list[ColorbarSpec] | None = [] if legend_params.colorbar else None
1041+
location_offsets: dict[str, float] = {"left": 0.0, "right": 0.0, "top": 0.0, "bottom": 0.0}
9111042

9121043
wants_images = False
9131044
wants_labels = False
@@ -937,6 +1068,7 @@ def show(
9371068
fig_params=fig_params,
9381069
scalebar_params=scalebar_params,
9391070
legend_params=legend_params,
1071+
colorbar_requests=axis_colorbar_requests,
9401072
rasterize=rasterize,
9411073
)
9421074

@@ -954,6 +1086,7 @@ def show(
9541086
fig_params=fig_params,
9551087
scalebar_params=scalebar_params,
9561088
legend_params=legend_params,
1089+
colorbar_requests=axis_colorbar_requests,
9571090
)
9581091

9591092
elif cmd == "render_points" and has_points:
@@ -970,6 +1103,7 @@ def show(
9701103
fig_params=fig_params,
9711104
scalebar_params=scalebar_params,
9721105
legend_params=legend_params,
1106+
colorbar_requests=axis_colorbar_requests,
9731107
)
9741108

9751109
elif cmd == "render_labels" and has_labels:
@@ -978,7 +1112,8 @@ def show(
9781112
)
9791113

9801114
if wanted_labels_on_this_cs:
981-
if (table := params_copy.table_name) is not None:
1115+
table = params_copy.table_name
1116+
if table is not None:
9821117
assert isinstance(params_copy.color, str)
9831118
colors = sc.get.obs_df(sdata[table], [params_copy.color])
9841119
if isinstance(colors[params_copy.color].dtype, pd.CategoricalDtype):
@@ -1002,6 +1137,7 @@ def show(
10021137
fig_params=fig_params,
10031138
scalebar_params=scalebar_params,
10041139
legend_params=legend_params,
1140+
colorbar_requests=axis_colorbar_requests,
10051141
rasterize=rasterize,
10061142
)
10071143

@@ -1038,6 +1174,27 @@ def show(
10381174
ax.set_xlim(x_min, x_max)
10391175
ax.set_ylim(y_max, y_min) # (0, 0) is top-left
10401176

1177+
if legend_params.colorbar and axis_colorbar_requests:
1178+
# keep only unique bars per (location, label, layout+kwargs). Last request wins.
1179+
unique_specs_map: dict[tuple[Any, ...], ColorbarSpec] = {}
1180+
for spec in axis_colorbar_requests:
1181+
layer_layout, layer_kwargs, layer_label_override = _split_colorbar_params(spec.params)
1182+
global_layout, global_kwargs, global_label_override = _split_colorbar_params(colorbar_params)
1183+
layout = {**{"location": "right"}, **layer_layout, **global_layout}
1184+
kwargs_key = tuple(sorted({**layer_kwargs, **global_kwargs}.items()))
1185+
label_key = global_label_override or layer_label_override or spec.label
1186+
key = (
1187+
layout.get("location", "right"),
1188+
label_key,
1189+
tuple(sorted(layout.items())),
1190+
kwargs_key,
1191+
)
1192+
unique_specs_map[key] = spec
1193+
unique_specs = list(unique_specs_map.values())
1194+
1195+
for spec in unique_specs:
1196+
_draw_colorbar(spec, location_offsets)
1197+
10411198
if fig_params.fig is not None and save is not None:
10421199
save_fig(fig_params.fig, path=save)
10431200

0 commit comments

Comments
 (0)