11from __future__ import annotations
22
3+ import contextlib
34import sys
5+ import warnings
46from collections import OrderedDict
57from copy import deepcopy
68from pathlib import Path
7- from typing import Any , Literal
9+ from typing import Any , Literal , cast
810
911import matplotlib .pyplot as plt
1012import numpy as np
1719from matplotlib .axes import Axes
1820from matplotlib .colors import Colormap , Normalize
1921from matplotlib .figure import Figure
22+ from mpl_toolkits .axes_grid1 .inset_locator import inset_axes
2023from spatialdata import get_extent
2124from spatialdata ._utils import _deprecation_alias
2225from xarray import DataArray , DataTree
2831 _render_labels ,
2932 _render_points ,
3033 _render_shapes ,
34+ _split_colorbar_params ,
3135)
3236from spatialdata_plot .pl .render_params import (
3337 CmapParams ,
38+ ColorbarSpec ,
3439 ImageRenderParams ,
3540 LabelsRenderParams ,
3641 LegendParams ,
@@ -172,6 +177,8 @@ def render_shapes(
172177 table_name : str | None = None ,
173178 table_layer : str | None = None ,
174179 shape : Literal ["circle" , "hex" , "visium_hex" , "square" ] | None = None ,
180+ colorbar : bool | str | None = "auto" ,
181+ colorbar_params : dict [str , object ] | None = None ,
175182 ** kwargs : Any ,
176183 ) -> sd .SpatialData :
177184 """
@@ -237,6 +244,11 @@ def render_shapes(
237244 method : str | None, optional
238245 Whether to use 'matplotlib' and 'datashader'. When None, the method is
239246 chosen based on the size of the data.
247+ colorbar :
248+ Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
249+ colorbar_params :
250+ Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
251+ and ``label``.
240252 table_name: str | None
241253 Name of the table containing the color(s) columns. If one name is given than the table is used for each
242254 spatial element to be plotted if the table annotates it. If you want to use different tables for particular
@@ -292,6 +304,8 @@ def render_shapes(
292304 shape = shape ,
293305 method = method ,
294306 ds_reduction = kwargs .get ("datashader_reduction" ),
307+ colorbar = colorbar ,
308+ colorbar_params = colorbar_params ,
295309 )
296310
297311 sdata = self ._copy ()
@@ -326,6 +340,8 @@ def render_shapes(
326340 zorder = n_steps ,
327341 method = param_values ["method" ],
328342 ds_reduction = param_values ["ds_reduction" ],
343+ colorbar = param_values ["colorbar" ],
344+ colorbar_params = param_values ["colorbar_params" ],
329345 )
330346 n_steps += 1
331347
@@ -347,6 +363,8 @@ def render_points(
347363 method : str | None = None ,
348364 table_name : str | None = None ,
349365 table_layer : str | None = None ,
366+ colorbar : bool | str | None = "auto" ,
367+ colorbar_params : dict [str , object ] | None = None ,
350368 ** kwargs : Any ,
351369 ) -> sd .SpatialData :
352370 """
@@ -396,6 +414,11 @@ def render_points(
396414 method : str | None, optional
397415 Whether to use 'matplotlib' and 'datashader'. When None, the method is
398416 chosen based on the size of the data.
417+ colorbar :
418+ Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
419+ colorbar_params :
420+ Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
421+ and ``label``.
399422 table_name: str | None
400423 Name of the table containing the color(s) columns. If one name is given than the table is used for each
401424 spatial element to be plotted if the table annotates it. If you want to use different tables for particular
@@ -434,6 +457,8 @@ def render_points(
434457 table_name = table_name ,
435458 table_layer = table_layer ,
436459 ds_reduction = kwargs .get ("datashader_reduction" ),
460+ colorbar = colorbar ,
461+ colorbar_params = colorbar_params ,
437462 )
438463
439464 if method is not None :
@@ -467,6 +492,8 @@ def render_points(
467492 zorder = n_steps ,
468493 method = method ,
469494 ds_reduction = param_values ["ds_reduction" ],
495+ colorbar = param_values ["colorbar" ],
496+ colorbar_params = param_values ["colorbar_params" ],
470497 )
471498 n_steps += 1
472499
@@ -484,6 +511,8 @@ def render_images(
484511 palette : list [str ] | str | None = None ,
485512 alpha : float | int = 1.0 ,
486513 scale : str | None = None ,
514+ colorbar : bool | str | None = "auto" ,
515+ colorbar_params : dict [str , object ] | None = None ,
487516 ** kwargs : Any ,
488517 ) -> sd .SpatialData :
489518 """
@@ -526,6 +555,11 @@ def render_images(
526555 3) "full": Renders the full image without rasterization. In the case of
527556 multiscale images, the highest resolution scale is selected. Note that
528557 this may result in long computing times for large images.
558+ colorbar :
559+ Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
560+ colorbar_params :
561+ Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
562+ and ``label``.
529563 kwargs
530564 Additional arguments to be passed to cmap, norm, and other rendering functions.
531565
@@ -547,6 +581,8 @@ def render_images(
547581 cmap = cmap ,
548582 norm = norm ,
549583 scale = scale ,
584+ colorbar = colorbar ,
585+ colorbar_params = colorbar_params ,
550586 )
551587
552588 sdata = self ._copy ()
@@ -580,6 +616,8 @@ def render_images(
580616 alpha = param_values ["alpha" ],
581617 scale = param_values ["scale" ],
582618 zorder = n_steps ,
619+ colorbar = param_values ["colorbar" ],
620+ colorbar_params = param_values ["colorbar_params" ],
583621 )
584622 n_steps += 1
585623
@@ -600,6 +638,8 @@ def render_labels(
600638 outline_alpha : float | int = 0.0 ,
601639 fill_alpha : float | int = 0.4 ,
602640 scale : str | None = None ,
641+ colorbar : bool | str | None = "auto" ,
642+ colorbar_params : dict [str , object ] | None = None ,
603643 table_name : str | None = None ,
604644 table_layer : str | None = None ,
605645 ** kwargs : Any ,
@@ -653,6 +693,11 @@ def render_labels(
653693 (exception: a dpi is specified in `show()`. Then the image is rasterized to fit the canvas and dpi).
654694 3) "full": render the full image without rasterization. In the case of a multiscale image, the scale
655695 with the highest resolution is selected. This can lead to long computing times for large images!
696+ colorbar :
697+ Whether to request a colorbar for continuous colors. Use "auto" (default) for automatic selection.
698+ colorbar_params :
699+ Parameters forwarded to Matplotlib's colorbar alongside layout hints such as ``loc``, ``width``, ``pad``,
700+ and ``label``.
656701 table_name: str | None
657702 Name of the table containing the color columns.
658703 table_layer: str | None
@@ -681,6 +726,8 @@ def render_labels(
681726 outline_alpha = outline_alpha ,
682727 palette = palette ,
683728 scale = scale ,
729+ colorbar = colorbar ,
730+ colorbar_params = colorbar_params ,
684731 table_name = table_name ,
685732 table_layer = table_layer ,
686733 )
@@ -709,6 +756,8 @@ def render_labels(
709756 table_name = param_values ["table_name" ],
710757 table_layer = param_values ["table_layer" ],
711758 zorder = n_steps ,
759+ colorbar = param_values ["colorbar" ],
760+ colorbar_params = param_values ["colorbar_params" ],
712761 )
713762 n_steps += 1
714763 return sdata
@@ -722,7 +771,8 @@ def show(
722771 legend_loc : str | None = "right margin" ,
723772 legend_fontoutline : int | None = None ,
724773 na_in_legend : bool = True ,
725- colorbar : bool = True ,
774+ colorbar : bool | None = None ,
775+ colorbar_params : dict [str , object ] | None = None ,
726776 wspace : float | None = None ,
727777 hspace : float = 0.25 ,
728778 ncols : int = 4 ,
@@ -761,7 +811,11 @@ def show(
761811 return_ax :
762812 Whether to return the axes object created. False by default.
763813 colorbar :
764- Whether to plot the colorbar. True by default.
814+ DEPRECATED: use per-layer ``colorbar``/``colorbar_params`` instead. If provided, it will be honored
815+ but will emit a DeprecationWarning.
816+ colorbar_params :
817+ Global overrides passed to colorbars for all axes. Accepts the same keys as per-layer ``colorbar_params``
818+ (e.g., ``loc``, ``width``, ``pad``, ``label``).
765819 title :
766820 The title of the plot. If not provided the plot will have the name of the coordinate system as title.
767821
@@ -786,6 +840,7 @@ def show(
786840 legend_fontoutline ,
787841 na_in_legend ,
788842 colorbar ,
843+ colorbar_params ,
789844 wspace ,
790845 hspace ,
791846 ncols ,
@@ -888,15 +943,89 @@ def show(
888943 ncols = ncols ,
889944 frameon = frameon ,
890945 )
946+ # colorbar is deprecated: warn when explicitly provided
947+ legend_colorbar = True if colorbar is None else colorbar
948+ if colorbar is not None :
949+ warnings .warn (
950+ "Parameter 'colorbar' in `.show()` is deprecated. "
951+ "Please control colorbars via the per-layer `colorbar` and `colorbar_params` arguments." ,
952+ DeprecationWarning ,
953+ stacklevel = 2 ,
954+ )
955+
891956 legend_params = LegendParams (
892957 legend_fontsize = legend_fontsize ,
893958 legend_fontweight = legend_fontweight ,
894959 legend_loc = legend_loc ,
895960 legend_fontoutline = legend_fontoutline ,
896961 na_in_legend = na_in_legend ,
897- colorbar = colorbar ,
962+ colorbar = legend_colorbar ,
898963 )
899964
965+ def _draw_colorbar (spec : ColorbarSpec , location_offsets : dict [str , float ]) -> None :
966+ base_layout = {"location" : "right" , "fraction" : 0.08 , "pad" : 0.01 }
967+ layer_layout , layer_kwargs , layer_label_override = _split_colorbar_params (spec .params )
968+ global_layout , global_kwargs , global_label_override = _split_colorbar_params (colorbar_params )
969+ layout = {** base_layout , ** layer_layout , ** global_layout }
970+ cbar_kwargs = {** layer_kwargs , ** global_kwargs }
971+
972+ location = cast (str , layout .get ("location" , base_layout ["location" ]))
973+ default_orientation = "vertical" if location in {"right" , "left" } else "horizontal"
974+ cbar_kwargs .setdefault ("orientation" , default_orientation )
975+
976+ fraction = float (cast (float | int , layout .get ("fraction" , base_layout ["fraction" ])))
977+ pad = float (cast (float | int , layout .get ("pad" , base_layout ["pad" ])))
978+ offset = location_offsets .get (location , 0.0 )
979+ pad = pad + offset
980+ # update offset for the next bar on the same side (space consumed = pad + fraction)
981+ location_offsets [location ] = offset + pad + fraction
982+
983+ if location == "right" :
984+ bbox = [1 + pad , 0 , fraction , 1 ]
985+ elif location == "left" :
986+ bbox = [- (pad + fraction ), 0 , fraction , 1 ]
987+ elif location == "top" :
988+ bbox = [0 , 1 + pad , 1 , fraction ]
989+ elif location == "bottom" :
990+ bbox = [0 , - (pad + fraction ), 1 , fraction ]
991+ else :
992+ bbox = [1 + pad , 0 , fraction , 1 ]
993+
994+ cax = inset_axes (
995+ spec .ax ,
996+ width = "100%" ,
997+ height = "100%" ,
998+ loc = "center" ,
999+ bbox_to_anchor = bbox ,
1000+ bbox_transform = spec .ax .transAxes ,
1001+ borderpad = 0.0 ,
1002+ )
1003+
1004+ cb = fig_params .fig .colorbar (spec .mappable , cax = cax , ** cbar_kwargs )
1005+ if location == "left" :
1006+ cb .ax .yaxis .set_ticks_position ("left" )
1007+ cb .ax .yaxis .set_label_position ("left" )
1008+ cb .ax .tick_params (labelleft = True , labelright = False )
1009+ elif location == "top" :
1010+ cb .ax .xaxis .set_ticks_position ("top" )
1011+ cb .ax .xaxis .set_label_position ("top" )
1012+ cb .ax .tick_params (labeltop = True , labelbottom = False )
1013+ elif location == "right" :
1014+ cb .ax .yaxis .set_ticks_position ("right" )
1015+ cb .ax .yaxis .set_label_position ("right" )
1016+ cb .ax .tick_params (labelright = True , labelleft = False )
1017+ elif location == "bottom" :
1018+ cb .ax .xaxis .set_ticks_position ("bottom" )
1019+ cb .ax .xaxis .set_label_position ("bottom" )
1020+ cb .ax .tick_params (labelbottom = True , labeltop = False )
1021+
1022+ final_label = global_label_override or layer_label_override or spec .label
1023+ if final_label :
1024+ cb .set_label (final_label )
1025+ if spec .alpha is not None :
1026+ with contextlib .suppress (Exception ):
1027+ cb .solids .set_alpha (spec .alpha )
1028+
9001029 cs_contents = _get_cs_contents (sdata )
9011030
9021031 # go through tree
@@ -908,6 +1037,8 @@ def show(
9081037 )
9091038 ax = fig_params .ax if fig_params .axs is None else fig_params .axs [i ]
9101039 assert isinstance (ax , Axes )
1040+ axis_colorbar_requests : list [ColorbarSpec ] | None = [] if legend_params .colorbar else None
1041+ location_offsets : dict [str , float ] = {"left" : 0.0 , "right" : 0.0 , "top" : 0.0 , "bottom" : 0.0 }
9111042
9121043 wants_images = False
9131044 wants_labels = False
@@ -937,6 +1068,7 @@ def show(
9371068 fig_params = fig_params ,
9381069 scalebar_params = scalebar_params ,
9391070 legend_params = legend_params ,
1071+ colorbar_requests = axis_colorbar_requests ,
9401072 rasterize = rasterize ,
9411073 )
9421074
@@ -954,6 +1086,7 @@ def show(
9541086 fig_params = fig_params ,
9551087 scalebar_params = scalebar_params ,
9561088 legend_params = legend_params ,
1089+ colorbar_requests = axis_colorbar_requests ,
9571090 )
9581091
9591092 elif cmd == "render_points" and has_points :
@@ -970,6 +1103,7 @@ def show(
9701103 fig_params = fig_params ,
9711104 scalebar_params = scalebar_params ,
9721105 legend_params = legend_params ,
1106+ colorbar_requests = axis_colorbar_requests ,
9731107 )
9741108
9751109 elif cmd == "render_labels" and has_labels :
@@ -978,7 +1112,8 @@ def show(
9781112 )
9791113
9801114 if wanted_labels_on_this_cs :
981- if (table := params_copy .table_name ) is not None :
1115+ table = params_copy .table_name
1116+ if table is not None :
9821117 assert isinstance (params_copy .color , str )
9831118 colors = sc .get .obs_df (sdata [table ], [params_copy .color ])
9841119 if isinstance (colors [params_copy .color ].dtype , pd .CategoricalDtype ):
@@ -1002,6 +1137,7 @@ def show(
10021137 fig_params = fig_params ,
10031138 scalebar_params = scalebar_params ,
10041139 legend_params = legend_params ,
1140+ colorbar_requests = axis_colorbar_requests ,
10051141 rasterize = rasterize ,
10061142 )
10071143
@@ -1038,6 +1174,27 @@ def show(
10381174 ax .set_xlim (x_min , x_max )
10391175 ax .set_ylim (y_max , y_min ) # (0, 0) is top-left
10401176
1177+ if legend_params .colorbar and axis_colorbar_requests :
1178+ # keep only unique bars per (location, label, layout+kwargs). Last request wins.
1179+ unique_specs_map : dict [tuple [Any , ...], ColorbarSpec ] = {}
1180+ for spec in axis_colorbar_requests :
1181+ layer_layout , layer_kwargs , layer_label_override = _split_colorbar_params (spec .params )
1182+ global_layout , global_kwargs , global_label_override = _split_colorbar_params (colorbar_params )
1183+ layout = {** {"location" : "right" }, ** layer_layout , ** global_layout }
1184+ kwargs_key = tuple (sorted ({** layer_kwargs , ** global_kwargs }.items ()))
1185+ label_key = global_label_override or layer_label_override or spec .label
1186+ key = (
1187+ layout .get ("location" , "right" ),
1188+ label_key ,
1189+ tuple (sorted (layout .items ())),
1190+ kwargs_key ,
1191+ )
1192+ unique_specs_map [key ] = spec
1193+ unique_specs = list (unique_specs_map .values ())
1194+
1195+ for spec in unique_specs :
1196+ _draw_colorbar (spec , location_offsets )
1197+
10411198 if fig_params .fig is not None and save is not None :
10421199 save_fig (fig_params .fig , path = save )
10431200
0 commit comments