diff --git a/docs/colorbars_legends.py b/docs/colorbars_legends.py index 51ed495b4..e3e72f59b 100644 --- a/docs/colorbars_legends.py +++ b/docs/colorbars_legends.py @@ -469,6 +469,85 @@ ax = axs[1] ax.legend(hs2, loc="b", ncols=3, center=True, title="centered rows") axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Legend formatting demo") + +# %% [raw] raw_mimetype="text/restructuredtext" +# .. _ug_semantic_legends: +# Semantic legends +# ---------------- +# +# Legends usually annotate artists already drawn on an axes, but sometimes you need +# standalone semantic keys (categories, size scales, color levels, or geometry types). +# UltraPlot provides helper methods that build these entries directly: +# +# * :meth:`~ultraplot.axes.Axes.catlegend` +# * :meth:`~ultraplot.axes.Axes.sizelegend` +# * :meth:`~ultraplot.axes.Axes.numlegend` +# * :meth:`~ultraplot.axes.Axes.geolegend` + +# %% +import cartopy.crs as ccrs +import shapely.geometry as sg + +fig, ax = uplt.subplots(refwidth=4.2) +ax.format(title="Semantic legend helpers", grid=False) + +ax.catlegend( + ["A", "B", "C"], + colors={"A": "red7", "B": "green7", "C": "blue7"}, + markers={"A": "o", "B": "s", "C": "^"}, + loc="top", + frameon=False, +) +ax.sizelegend( + [10, 50, 200], + loc="upper right", + title="Population", + ncols=1, + frameon=False, +) +ax.numlegend( + vmin=0, + vmax=1, + n=5, + cmap="viko", + fmt="{:.2f}", + loc="ll", + ncols=1, + frameon=False, +) + +poly1 = sg.Polygon([(0, 0), (2, 0), (1.2, 1.4)]) +ax.geolegend( + [ + ("Triangle", "triangle"), + ("Triangle-ish", poly1), + ("Australia", "country:AU"), + ("Netherlands (Mercator)", "country:NLD", "mercator"), + ( + "Netherlands (Lambert)", + "country:NLD", + { + "country_proj": ccrs.LambertConformal( + central_longitude=5, + central_latitude=52, + ), + "country_reso": "10m", + "country_territories": False, + "facecolor": "steelblue", + "fill": True, + }, + ), + ], + loc="r", + ncols=1, + handlesize=2.4, + handletextpad=0.35, + frameon=False, + country_reso="10m", +) +ax.axis("off") + + # %% [raw] raw_mimetype="text/restructuredtext" # .. _ug_guides_decouple: # diff --git a/docs/examples/legends_colorbars/03_semantic_legends.py b/docs/examples/legends_colorbars/03_semantic_legends.py new file mode 100644 index 000000000..c6bc7e9cc --- /dev/null +++ b/docs/examples/legends_colorbars/03_semantic_legends.py @@ -0,0 +1,91 @@ +""" +Semantic legends +================ + +Build legends from semantic mappings rather than existing artists. + +Why UltraPlot here? +------------------- +UltraPlot adds semantic legend helpers directly on axes: +``catlegend``, ``sizelegend``, ``numlegend``, and ``geolegend``. +These are useful when you want legend meaning decoupled from plotted handles. + +Key functions: :py:meth:`ultraplot.axes.Axes.catlegend`, :py:meth:`ultraplot.axes.Axes.sizelegend`, :py:meth:`ultraplot.axes.Axes.numlegend`, :py:meth:`ultraplot.axes.Axes.geolegend`. + +See also +-------- +* :doc:`Colorbars and legends ` +""" + +# %% +import cartopy.crs as ccrs +import numpy as np +import shapely.geometry as sg +from matplotlib.path import Path + +import ultraplot as uplt + +np.random.seed(0) +data = np.random.randn(2, 100) +sizes = np.random.randint(10, 512, data.shape[1]) +colors = np.random.rand(data.shape[1]) + +fig, ax = uplt.subplots() +ax.scatter(*data, color=colors, s=sizes, cmap="viko") +ax.format(title="Semantic legend helpers") + +ax.catlegend( + ["A", "B", "C"], + colors={"A": "red7", "B": "green7", "C": "blue7"}, + markers={"A": "o", "B": "s", "C": "^"}, + loc="top", + frameon=False, +) +ax.sizelegend( + [10, 50, 200], + loc="upper right", + title="Population", + ncols=1, + frameon=False, +) +ax.numlegend( + vmin=0, + vmax=1, + n=5, + cmap="viko", + fmt="{:.2f}", + loc="ll", + ncols=1, + frameon=False, +) + +poly1 = sg.Polygon([(0, 0), (2, 0), (1.2, 1.4)]) +ax.geolegend( + [ + ("Triangle", "triangle"), + ("Triangle-ish", poly1), + ("Australia", "country:AU"), + ("Netherlands (Mercator)", "country:NLD", "mercator"), + ( + "Netherlands (Lambert)", + "country:NLD", + { + "country_proj": ccrs.LambertConformal( + central_longitude=5, + central_latitude=52, + ), + "country_reso": "10m", + "country_territories": False, + "facecolor": "steelblue", + "fill": True, + }, + ), + ], + loc="r", + ncols=1, + handlesize=2.4, + handletextpad=0.35, + frameon=False, + country_reso="10m", +) +fig.show() diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 759995f71..e685259f5 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -2100,10 +2100,12 @@ def _parse_legend_centered( return objs @staticmethod - def _parse_legend_group(handles, labels=None): + def _parse_legend_group(handles, labels=None, handler_map=None): """ Parse possibly tuple-grouped input handles. """ + handler_map_full = plegend.Legend.get_default_handler_map().copy() + handler_map_full.update(handler_map or {}) # Helper function. Retrieve labels from a tuple group or from objects # in a container. Multiple labels lead to multiple legend entries. @@ -2154,7 +2156,18 @@ def _legend_tuple(*objs): # noqa: E306 continue handles.append(obj) else: - warnings._warn_ultraplot(f"Ignoring invalid legend handle {obj!r}.") + try: + handler = plegend.Legend.get_legend_handler( + handler_map_full, obj + ) + except Exception: + handler = None + if handler is not None: + handles.append(obj) + else: + warnings._warn_ultraplot( + f"Ignoring invalid legend handle {obj!r}." + ) return tuple(handles) # Sanitize labels. Ignore e.g. extra hist() or hist2d() return values, @@ -2247,7 +2260,9 @@ def _parse_legend_handles( ihandles, ilabels = to_list(ihandles), to_list(ilabels) if ihandles is None: ihandles = self._get_legend_handles(handler_map) - ihandles, ilabels = self._parse_legend_group(ihandles, ilabels) + ihandles, ilabels = self._parse_legend_group( + ihandles, ilabels, handler_map=handler_map + ) ipairs = list(zip(ihandles, ilabels)) if alphabetize: ipairs = sorted(ipairs, key=lambda pair: pair[1]) @@ -3487,6 +3502,64 @@ def legend( **kwargs, ) + def catlegend(self, categories, **kwargs): + """ + Build categorical legend entries and optionally add a legend. + + Parameters + ---------- + categories + Category labels used to generate legend handles. + **kwargs + Forwarded to `ultraplot.legend.UltraLegend.catlegend`. + Pass ``add=False`` to return ``(handles, labels)`` without drawing. + """ + return plegend.UltraLegend(self).catlegend(categories, **kwargs) + + def sizelegend(self, levels, **kwargs): + """ + Build size legend entries and optionally add a legend. + + Parameters + ---------- + levels + Numeric levels used to generate marker-size entries. + **kwargs + Forwarded to `ultraplot.legend.UltraLegend.sizelegend`. + Pass ``add=False`` to return ``(handles, labels)`` without drawing. + """ + return plegend.UltraLegend(self).sizelegend(levels, **kwargs) + + def numlegend(self, levels=None, **kwargs): + """ + Build numeric-color legend entries and optionally add a legend. + + Parameters + ---------- + levels + Numeric levels or number of levels. + **kwargs + Forwarded to `ultraplot.legend.UltraLegend.numlegend`. + Pass ``add=False`` to return ``(handles, labels)`` without drawing. + """ + return plegend.UltraLegend(self).numlegend(levels=levels, **kwargs) + + def geolegend(self, entries, labels=None, **kwargs): + """ + Build geometry legend entries and optionally add a legend. + + Parameters + ---------- + entries + Geometry entries (mapping, ``(label, geometry)`` pairs, or geometries). + labels + Optional labels for geometry sequences. + **kwargs + Forwarded to `ultraplot.legend.UltraLegend.geolegend`. + Pass ``add=False`` to return ``(handles, labels)`` without drawing. + """ + return plegend.UltraLegend(self).geolegend(entries, labels=labels, **kwargs) + @classmethod def _coerce_curve_xy(cls, x, y): """ diff --git a/ultraplot/internals/rcsetup.py b/ultraplot/internals/rcsetup.py index 1029c7a3d..9f580154b 100644 --- a/ultraplot/internals/rcsetup.py +++ b/ultraplot/internals/rcsetup.py @@ -1397,6 +1397,171 @@ def _validator_accepts(validator, value): _validate_bool, "Whether to add a shadow underneath inset colorbar frames.", ), + # Semantic legend helper defaults + "legend.cat.line": ( + False, + _validate_bool, + "Default line/marker mode for `Axes.catlegend`.", + ), + "legend.cat.marker": ( + "o", + _validate_string, + "Default marker for `Axes.catlegend` entries.", + ), + "legend.cat.linestyle": ( + "-", + _validate_linestyle, + "Default line style for `Axes.catlegend` entries.", + ), + "legend.cat.linewidth": ( + 2.0, + _validate_float, + "Default line width for `Axes.catlegend` entries.", + ), + "legend.cat.markersize": ( + 6.0, + _validate_float, + "Default marker size for `Axes.catlegend` entries.", + ), + "legend.cat.alpha": ( + None, + _validate_or_none(_validate_float), + "Default alpha for `Axes.catlegend` entries.", + ), + "legend.cat.markeredgecolor": ( + None, + _validate_or_none(_validate_color), + "Default marker edge color for `Axes.catlegend` entries.", + ), + "legend.cat.markeredgewidth": ( + None, + _validate_or_none(_validate_float), + "Default marker edge width for `Axes.catlegend` entries.", + ), + "legend.size.color": ( + "0.35", + _validate_color, + "Default marker color for `Axes.sizelegend` entries.", + ), + "legend.size.marker": ( + "o", + _validate_string, + "Default marker for `Axes.sizelegend` entries.", + ), + "legend.size.area": ( + True, + _validate_bool, + "Whether `Axes.sizelegend` interprets levels as marker area by default.", + ), + "legend.size.scale": ( + 1.0, + _validate_float, + "Default marker size scale factor for `Axes.sizelegend` entries.", + ), + "legend.size.minsize": ( + 3.0, + _validate_float, + "Default minimum marker size for `Axes.sizelegend` entries.", + ), + "legend.size.format": ( + None, + _validate_or_none(_validate_string), + "Default label format string for `Axes.sizelegend` entries.", + ), + "legend.size.alpha": ( + None, + _validate_or_none(_validate_float), + "Default alpha for `Axes.sizelegend` entries.", + ), + "legend.size.markeredgecolor": ( + None, + _validate_or_none(_validate_color), + "Default marker edge color for `Axes.sizelegend` entries.", + ), + "legend.size.markeredgewidth": ( + None, + _validate_or_none(_validate_float), + "Default marker edge width for `Axes.sizelegend` entries.", + ), + "legend.num.n": ( + 5, + _validate_int, + "Default number of sampled levels for `Axes.numlegend`.", + ), + "legend.num.cmap": ( + "viridis", + _validate_cmap("continuous"), + "Default colormap for `Axes.numlegend` entries.", + ), + "legend.num.edgecolor": ( + "none", + _validate_or_none(_validate_color), + "Default edge color for `Axes.numlegend` patch entries.", + ), + "legend.num.linewidth": ( + 0.0, + _validate_float, + "Default edge width for `Axes.numlegend` patch entries.", + ), + "legend.num.alpha": ( + None, + _validate_or_none(_validate_float), + "Default alpha for `Axes.numlegend` entries.", + ), + "legend.num.format": ( + None, + _validate_or_none(_validate_string), + "Default label format string for `Axes.numlegend` entries.", + ), + "legend.geo.facecolor": ( + "none", + _validate_or_none(_validate_color), + "Default face color for `Axes.geolegend` entries.", + ), + "legend.geo.edgecolor": ( + "0.25", + _validate_or_none(_validate_color), + "Default edge color for `Axes.geolegend` entries.", + ), + "legend.geo.linewidth": ( + 1.0, + _validate_float, + "Default edge width for `Axes.geolegend` entries.", + ), + "legend.geo.alpha": ( + None, + _validate_or_none(_validate_float), + "Default alpha for `Axes.geolegend` entries.", + ), + "legend.geo.fill": ( + None, + _validate_or_none(_validate_bool), + "Default fill mode for `Axes.geolegend` entries.", + ), + "legend.geo.country_reso": ( + "110m", + _validate_belongs("10m", "50m", "110m"), + "Default Natural Earth resolution used for country shorthand geometry " + "entries in `Axes.geolegend`.", + ), + "legend.geo.country_territories": ( + False, + _validate_bool, + "Whether country shorthand entries in `Axes.geolegend` include far-away " + "territories instead of pruning to the local footprint.", + ), + "legend.geo.country_proj": ( + None, + _validate_or_none(_validate_string), + "Optional projection name for country shorthand entries in `Axes.geolegend`. " + "Can be overridden per call with a cartopy CRS or callable.", + ), + "legend.geo.handlesize": ( + 1.0, + _validate_float, + "Scale factor applied to both legend handle length and height for " + "`Axes.geolegend` when explicit handle dimensions are not provided.", + ), # Color cycle additions "cycle": ( CYCLE, diff --git a/ultraplot/legend.py b/ultraplot/legend.py index da7781d48..baee6e13c 100644 --- a/ultraplot/legend.py +++ b/ultraplot/legend.py @@ -1,13 +1,20 @@ from dataclasses import dataclass +from functools import lru_cache from typing import Any, Iterable, Optional, Tuple, Union -import numpy as np import matplotlib.patches as mpatches +import matplotlib.path as mpath import matplotlib.text as mtext +import numpy as np +from matplotlib import cm as mcm +from matplotlib import colors as mcolors from matplotlib import lines as mlines from matplotlib import legend as mlegend from matplotlib import legend_handler as mhandler -from matplotlib import patches as mpatches + +from .config import rc +from .internals import _not_none, _pop_props, guides, rcsetup +from .utils import _fontsize_to_pt, units from .config import rc from .internals import _not_none, _pop_props, guides, rcsetup @@ -18,7 +25,29 @@ except ImportError: from typing_extensions import override -__all__ = ["Legend", "LegendEntry"] +try: # optional cartopy-dependent geometry support + import cartopy.crs as ccrs + from cartopy.io import shapereader as cshapereader + from cartopy.mpl.feature_artist import FeatureArtist as _CartopyFeatureArtist + from cartopy.mpl.path import shapely_to_path as _cartopy_shapely_to_path +except Exception: + ccrs = None + cshapereader = None + _CartopyFeatureArtist = None + _cartopy_shapely_to_path = None + +try: # optional shapely support for direct geometry legend handles + from shapely.geometry.base import BaseGeometry as _ShapelyBaseGeometry + from shapely.ops import unary_union as _shapely_unary_union +except Exception: + _ShapelyBaseGeometry = None + _shapely_unary_union = None + +__all__ = [ + "Legend", + "LegendEntry", + "GeometryEntry", +] def _wedge_legend_patch( @@ -104,6 +133,841 @@ def marker(cls, label=None, marker="o", **kwargs): return cls(label=label, line=False, marker=marker, **kwargs) +_GEOMETRY_SHAPE_PATHS = { + "circle": mpath.Path.unit_circle(), + "square": mpath.Path.unit_rectangle(), + "triangle": mpath.Path.unit_regular_polygon(3), + "diamond": mpath.Path.unit_regular_polygon(4), + "pentagon": mpath.Path.unit_regular_polygon(5), + "hexagon": mpath.Path.unit_regular_polygon(6), + "star": mpath.Path.unit_regular_star(5), +} +_GEOMETRY_SHAPE_ALIASES = { + "box": "square", + "rect": "square", + "rectangle": "square", + "tri": "triangle", + "pent": "pentagon", + "hex": "hexagon", +} +_DEFAULT_GEO_JOINSTYLE = "bevel" + + +def _normalize_shape_name(value: str) -> str: + """ + Normalize geometry shape shorthand names. + """ + key = str(value).strip().lower().replace("_", "").replace("-", "").replace(" ", "") + return _GEOMETRY_SHAPE_ALIASES.get(key, key) + + +def _normalize_country_resolution(resolution: str) -> str: + """ + Normalize Natural Earth shorthand resolution. + """ + value = str(resolution).strip().lower() + if value in {"10", "10m"}: + return "10m" + if value in {"50", "50m"}: + return "50m" + if value in {"110", "110m"}: + return "110m" + raise ValueError( + f"Invalid country resolution {resolution!r}. " + "Use one of: '10m', '50m', '110m'." + ) + + +def _country_geometry_for_legend(geometry: Any, *, include_far: bool = False) -> Any: + """ + Reduce multi-part country geometry for readability while preserving local islands. + + This avoids tiny legend glyphs for countries with distant overseas territories + (e.g., Netherlands in Natural Earth datasets), but tries to keep nearby islands. + """ + if include_far: + return geometry + geoms = getattr(geometry, "geoms", None) + if geoms is None: + return geometry + parts = [] + for part in geoms: + area = float(getattr(part, "area", 0.0) or 0.0) + if area > 0: + parts.append((area, part)) + if not parts: + return geometry + dominant = max(parts, key=lambda item: item[0])[1] + + # Preserve local components near the dominant polygon (e.g. nearby coastal islands) + # while dropping very distant territories that make legend glyphs too tiny. + minx, miny, maxx, maxy = dominant.bounds + span = max(maxx - minx, maxy - miny, 1e-6) + neighborhood = dominant.buffer(1.5 * span) + keep = [part for _, part in parts if part.intersects(neighborhood)] + if not keep: + return dominant + if len(keep) == 1: + return keep[0] + if _shapely_unary_union is None: + return dominant + try: + return _shapely_unary_union(keep) + except Exception: + return dominant + + +def _resolve_country_projection(country_proj: Any) -> Any: + """ + Resolve shorthand strings to cartopy projections for country legend geometries. + """ + if country_proj is None: + return None + if callable(country_proj) and not hasattr(country_proj, "project_geometry"): + return country_proj + if hasattr(country_proj, "project_geometry"): + return country_proj + if isinstance(country_proj, str): + if ccrs is None: + raise ValueError( + "country_proj requires cartopy. Install cartopy or pass a callable." + ) + key = ( + country_proj.strip() + .lower() + .replace("_", "") + .replace("-", "") + .replace(" ", "") + ) + mapping = { + "platecarree": ccrs.PlateCarree, + "pc": ccrs.PlateCarree, + "mercator": ccrs.Mercator, + "robinson": ccrs.Robinson, + "mollweide": ccrs.Mollweide, + "equalearth": ccrs.EqualEarth, + "orthographic": ccrs.Orthographic, + } + if key not in mapping: + raise ValueError( + f"Unknown country_proj {country_proj!r}. " + "Use a cartopy CRS, callable, or one of: " + + ", ".join(sorted(mapping)) + + "." + ) + # Orthographic needs center lon/lat. + if key == "orthographic": + return mapping[key](0, 0) + return mapping[key]() + raise ValueError( + "country_proj must be None, a cartopy CRS, a projection name string, or " + "a callable accepting and returning a geometry." + ) + + +def _project_geometry_for_legend(geometry: Any, country_proj: Any) -> Any: + """ + Project geometry for legend rendering when requested. + """ + projection = _resolve_country_projection(country_proj) + if projection is None: + return geometry + if callable(projection) and not hasattr(projection, "project_geometry"): + out = projection(geometry) + if out is None: + raise ValueError("country_proj callable returned None geometry.") + return out + if ccrs is None: + raise ValueError( + "country_proj cartopy projection requested but cartopy missing." + ) + try: + return projection.project_geometry(geometry, src_crs=ccrs.PlateCarree()) + except TypeError: + return projection.project_geometry(geometry, ccrs.PlateCarree()) + + +@lru_cache(maxsize=256) +def _resolve_country_geometry( + code: str, resolution: str = "110m", include_far: bool = False +): + """ + Resolve a country shorthand code (e.g., ``AU`` or ``AUS``) to a geometry. + """ + if cshapereader is None: + raise ValueError( + "Country shorthand requires cartopy's shapereader support. " + "Pass a shapely geometry directly instead." + ) + key = str(code).strip().upper() + if not key: + raise ValueError("Country shorthand cannot be empty.") + resolution = _normalize_country_resolution(resolution) + try: + path = cshapereader.natural_earth( + resolution=resolution, + category="cultural", + name="admin_0_countries", + ) + reader = cshapereader.Reader(path) + except Exception as exc: + raise ValueError( + "Unable to load Natural Earth country geometries for shorthand parsing. " + "This usually means cartopy data is not available offline yet. " + "Pass a shapely geometry directly (e.g. from GeoPandas), or pre-download " + "the Natural Earth dataset." + ) from exc + + fields = ( + "ADM0_A3", + "ISO_A3", + "ISO_A3_EH", + "SOV_A3", + "SU_A3", + "GU_A3", + "BRK_A3", + "ADM0_A3_US", + "ISO_A2", + "ISO_A2_EH", + "ABBREV", + "NAME", + "NAME_LONG", + "ADMIN", + ) + for record in reader.records(): + attrs = record.attributes or {} + values = {str(attrs.get(field, "")).strip().upper() for field in fields} + values.discard("") + if key in values: + return _country_geometry_for_legend( + record.geometry, include_far=include_far + ) + raise ValueError(f"Unknown country shorthand {code!r}.") + + +def _geometry_to_path( + geometry: Any, + *, + country_reso: str = "110m", + country_territories: bool = False, + country_proj: Any = None, +) -> mpath.Path: + """ + Convert geometry/path shorthand input to a matplotlib path. + """ + if isinstance(geometry, mpath.Path): + return geometry + if isinstance(geometry, str): + spec = geometry.strip() + shape = _normalize_shape_name(spec) + if shape in _GEOMETRY_SHAPE_PATHS: + return _GEOMETRY_SHAPE_PATHS[shape] + if spec.lower().startswith("country:"): + geometry = _resolve_country_geometry( + spec.split(":", 1)[1], + country_reso, + include_far=country_territories, + ) + geometry = _project_geometry_for_legend(geometry, country_proj) + elif spec.isalpha() and len(spec) in (2, 3): + geometry = _resolve_country_geometry( + spec, + country_reso, + include_far=country_territories, + ) + geometry = _project_geometry_for_legend(geometry, country_proj) + else: + options = ", ".join(sorted(_GEOMETRY_SHAPE_PATHS)) + raise ValueError( + f"Unknown geometry shorthand {geometry!r}. " + f"Use a shapely geometry, country code, or one of: {options}." + ) + if hasattr(geometry, "geom_type") and _cartopy_shapely_to_path is not None: + return _cartopy_shapely_to_path(geometry) + raise TypeError( + "Geometry must be a matplotlib Path, shapely geometry, geometry shorthand, " + "or country shorthand." + ) + + +def _fit_path_to_handlebox( + path: mpath.Path, + *, + xdescent: float, + ydescent: float, + width: float, + height: float, + pad: float = 0.08, +) -> mpath.Path: + """ + Normalize an arbitrary path into the legend-handle box. + """ + verts = np.array(path.vertices, copy=True, dtype=float) + finite = np.isfinite(verts).all(axis=1) + if not finite.any(): + return mpath.Path.unit_rectangle() + xmin, ymin = verts[finite].min(axis=0) + xmax, ymax = verts[finite].max(axis=0) + dx = max(float(xmax - xmin), 1e-12) + dy = max(float(ymax - ymin), 1e-12) + px = max(width * pad, 0.0) + py = max(height * pad, 0.0) + span_x = max(width - 2 * px, 1e-12) + span_y = max(height - 2 * py, 1e-12) + scale = min(span_x / dx, span_y / dy) + cx = -xdescent + width * 0.5 + cy = -ydescent + height * 0.5 + verts[finite, 0] = (verts[finite, 0] - (xmin + xmax) * 0.5) * scale + cx + verts[finite, 1] = (verts[finite, 1] - (ymin + ymax) * 0.5) * scale + cy + return mpath.Path( + verts, None if path.codes is None else np.array(path.codes, copy=True) + ) + + +def _feature_geometry_path(handle: Any) -> Optional[mpath.Path]: + """ + Extract the first geometry path from a cartopy feature artist. + """ + feature = getattr(handle, "_feature", None) + if feature is None or _cartopy_shapely_to_path is None: + return None + geoms = getattr(feature, "geometries", None) + if geoms is None: + return None + try: + iterator = iter(geoms()) + except Exception: + return None + try: + geometry = next(iterator) + except StopIteration: + return None + try: + return _cartopy_shapely_to_path(geometry) + except Exception: + return None + + +def _first_scalar(value: Any, default: Any = None) -> Any: + """ + Return first scalar from lists/arrays used by collection-style artists. + """ + if value is None: + return default + if isinstance(value, np.ndarray): + if value.size == 0: + return default + if value.ndim == 0: + return value.item() + if value.ndim >= 2: + item = value[0] + else: + item = value + if isinstance(item, np.ndarray) and item.size == 1: + return item.item() + return item + if isinstance(value, (list, tuple)): + if not value: + return default + item = value[0] + if isinstance(item, np.ndarray) and item.size == 1: + return item.item() + return item + return value + + +def _patch_joinstyle(value: Any, default: str = _DEFAULT_GEO_JOINSTYLE) -> str: + """ + Resolve patch joinstyle from artist methods/kwargs with a sensible default. + """ + getter = getattr(value, "get_joinstyle", None) + if callable(getter): + try: + joinstyle = getter() + except Exception: + joinstyle = None + if joinstyle: + return joinstyle + kwargs = getattr(value, "_kwargs", None) + if isinstance(kwargs, dict): + for key in ("joinstyle", "solid_joinstyle", "linejoin"): + joinstyle = kwargs.get(key, None) + if joinstyle: + return joinstyle + return default + + +def _feature_legend_patch( + legend, + orig_handle, + xdescent, + ydescent, + width, + height, + fontsize, +): + """ + Draw a normalized geometry path for cartopy feature artists. + """ + path = _feature_geometry_path(orig_handle) + if path is None: + path = mpath.Path.unit_rectangle() + path = _fit_path_to_handlebox( + path, + xdescent=xdescent, + ydescent=ydescent, + width=width, + height=height, + ) + return mpatches.PathPatch(path, joinstyle=_DEFAULT_GEO_JOINSTYLE) + + +def _shapely_geometry_patch( + legend, + orig_handle, + xdescent, + ydescent, + width, + height, + fontsize, +): + """ + Draw shapely geometry handles in legend boxes. + """ + if _cartopy_shapely_to_path is None: + path = mpath.Path.unit_rectangle() + else: + try: + path = _cartopy_shapely_to_path(orig_handle) + except Exception: + path = mpath.Path.unit_rectangle() + path = _fit_path_to_handlebox( + path, + xdescent=xdescent, + ydescent=ydescent, + width=width, + height=height, + ) + return mpatches.PathPatch(path, joinstyle=_DEFAULT_GEO_JOINSTYLE) + + +def _geometry_entry_patch( + legend, + orig_handle, + xdescent, + ydescent, + width, + height, + fontsize, +): + """ + Draw a geometry entry path inside the legend-handle box. + """ + path = _fit_path_to_handlebox( + orig_handle.get_path(), + xdescent=xdescent, + ydescent=ydescent, + width=width, + height=height, + ) + return mpatches.PathPatch(path, joinstyle=_DEFAULT_GEO_JOINSTYLE) + + +class _FeatureArtistLegendHandler(mhandler.HandlerPatch): + """ + Legend handler for cartopy FeatureArtist instances. + """ + + def __init__(self): + super().__init__(patch_func=_feature_legend_patch) + + def update_prop(self, legend_handle, orig_handle, legend): + facecolor = _first_scalar( + ( + orig_handle.get_facecolor() + if hasattr(orig_handle, "get_facecolor") + else None + ), + default="none", + ) + edgecolor = _first_scalar( + ( + orig_handle.get_edgecolor() + if hasattr(orig_handle, "get_edgecolor") + else None + ), + default="none", + ) + linewidth = _first_scalar( + ( + orig_handle.get_linewidth() + if hasattr(orig_handle, "get_linewidth") + else None + ), + default=0.0, + ) + legend_handle.set_facecolor(facecolor) + legend_handle.set_edgecolor(edgecolor) + legend_handle.set_linewidth(linewidth) + legend_handle.set_joinstyle(_patch_joinstyle(orig_handle)) + if hasattr(orig_handle, "get_alpha"): + legend_handle.set_alpha(orig_handle.get_alpha()) + legend._set_artist_props(legend_handle) + legend_handle.set_clip_box(None) + legend_handle.set_clip_path(None) + + +class _ShapelyGeometryLegendHandler(mhandler.HandlerPatch): + """ + Legend handler for raw shapely geometries. + """ + + def __init__(self): + super().__init__(patch_func=_shapely_geometry_patch) + + def update_prop(self, legend_handle, orig_handle, legend): + # No style information is stored on shapely geometry objects. + legend_handle.set_joinstyle(_DEFAULT_GEO_JOINSTYLE) + legend._set_artist_props(legend_handle) + legend_handle.set_clip_box(None) + legend_handle.set_clip_path(None) + + +class _GeometryEntryLegendHandler(mhandler.HandlerPatch): + """ + Legend handler for `GeometryEntry` custom handles. + """ + + def __init__(self): + super().__init__(patch_func=_geometry_entry_patch) + + def update_prop(self, legend_handle, orig_handle, legend): + super().update_prop(legend_handle, orig_handle, legend) + legend_handle.set_joinstyle(_patch_joinstyle(orig_handle)) + legend_handle.set_clip_box(None) + legend_handle.set_clip_path(None) + + +class GeometryEntry(mpatches.PathPatch): + """ + Convenience geometry legend entry. + + Parameters + ---------- + geometry + Geometry shorthand (e.g. ``'triangle'`` or ``'country:AU'``), + shapely geometry, or `matplotlib.path.Path`. + """ + + def __init__( + self, + geometry: Any = "square", + *, + country_reso: str = "110m", + country_territories: bool = False, + country_proj: Any = None, + label: Optional[str] = None, + facecolor: Any = "none", + edgecolor: Any = "0.25", + linewidth: float = 1.0, + joinstyle: str = _DEFAULT_GEO_JOINSTYLE, + alpha: Optional[float] = None, + fill: Optional[bool] = None, + **kwargs: Any, + ): + path = _geometry_to_path( + geometry, + country_reso=country_reso, + country_territories=country_territories, + country_proj=country_proj, + ) + if fill is None: + fill = facecolor not in (None, "none") + super().__init__( + path=path, + label=label, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + joinstyle=joinstyle, + alpha=alpha, + fill=fill, + **kwargs, + ) + self._ultraplot_geometry = geometry + + +def _geometry_default_label(geometry: Any, index: int) -> str: + """ + Derive default labels for geo legend entries. + """ + if isinstance(geometry, str): + return geometry + return f"Entry {index + 1}" + + +def _geo_legend_entries( + entries: Iterable[Any] | dict[Any, Any], + labels: Optional[Iterable[Any]] = None, + *, + country_reso: str = "110m", + country_territories: bool = False, + country_proj: Any = None, + facecolor: Any = "none", + edgecolor: Any = "0.25", + linewidth: float = 1.0, + alpha: Optional[float] = None, + fill: Optional[bool] = None, +): + """ + Build geometry semantic legend handles and labels. + + Notes + ----- + `entries` may be: + - mapping of ``label -> geometry`` + - sequence of ``(label, geometry)`` or ``(label, geometry, options)`` tuples + where ``options`` is either a projection spec or a dict of per-entry + `GeometryEntry` keyword overrides (e.g., `country_proj`, `country_reso`) + - sequence of geometries with explicit `labels` + """ + entry_options = None + if isinstance(entries, dict): + label_list = [str(label) for label in entries] + geometry_list = list(entries.values()) + entry_options = [{} for _ in geometry_list] + else: + entries = list(entries) + if labels is None and all( + isinstance(entry, tuple) and len(entry) in (2, 3) for entry in entries + ): + label_list = [] + geometry_list = [] + entry_options = [] + for entry in entries: + if len(entry) == 2: + label, geometry = entry + options = {} + else: + label, geometry, options = entry + if options is None: + options = {} + elif isinstance(options, dict): + options = dict(options) + else: + # Convenience shorthand for per-entry projection only. + options = {"country_proj": options} + label_list.append(str(label)) + geometry_list.append(geometry) + entry_options.append(options) + else: + geometry_list = list(entries) + entry_options = [{} for _ in geometry_list] + if labels is None: + label_list = [ + _geometry_default_label(geometry, idx) + for idx, geometry in enumerate(geometry_list) + ] + else: + label_list = [str(label) for label in labels] + if len(label_list) != len(geometry_list): + raise ValueError( + "Labels and geometry entries must have the same length. " + f"Got {len(label_list)} labels and {len(geometry_list)} entries." + ) + handles = [] + for geometry, label, options in zip(geometry_list, label_list, entry_options): + geo_kwargs = { + "country_reso": country_reso, + "country_territories": country_territories, + "country_proj": country_proj, + "facecolor": facecolor, + "edgecolor": edgecolor, + "linewidth": linewidth, + "alpha": alpha, + "fill": fill, + } + geo_kwargs.update(options or {}) + handles.append(GeometryEntry(geometry, label=label, **geo_kwargs)) + return handles, label_list + + +def _style_lookup(style, key, index, default=None): + """ + Resolve style values from scalar, mapping, or sequence inputs. + """ + if style is None: + return default + if isinstance(style, dict): + return style.get(key, default) + if isinstance(style, str): + return style + try: + values = list(style) + except TypeError: + return style + if not values: + return default + return values[index % len(values)] + + +def _format_label(value, fmt): + """ + Format legend labels from values. + """ + if fmt is None: + return f"{value:g}" if isinstance(value, (float, np.floating)) else str(value) + if callable(fmt): + return str(fmt(value)) + return fmt.format(value) + + +def _default_cycle_colors(): + """ + Return default color cycle entries. + """ + try: + import matplotlib as mpl + + colors = mpl.rcParams["axes.prop_cycle"].by_key().get("color", None) + except Exception: + colors = None + return colors or ["C0"] + + +def _cat_legend_entries( + categories: Iterable[Any], + *, + colors=None, + markers="o", + line: bool = False, + linestyle: str = "-", + linewidth: float = 2.0, + markersize: float = 6.0, + alpha=None, + markeredgecolor=None, + markeredgewidth=None, +): + """ + Build categorical semantic legend handles and labels. + """ + labels = list(dict.fromkeys(categories)) + palette = _default_cycle_colors() + handles = [] + for idx, label in enumerate(labels): + color = _style_lookup(colors, label, idx, default=palette[idx % len(palette)]) + marker = _style_lookup(markers, label, idx, default="o") + if line and marker in (None, ""): + marker = None + handles.append( + LegendEntry( + label=str(label), + color=color, + line=line, + marker=marker, + linestyle=linestyle, + linewidth=linewidth, + markersize=markersize, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + alpha=alpha, + ) + ) + return handles, [str(label) for label in labels] + + +def _size_legend_entries( + levels: Iterable[float], + *, + color="0.35", + marker: str = "o", + area: bool = True, + scale: float = 1.0, + minsize: float = 3.0, + fmt=None, + alpha=None, + markeredgecolor=None, + markeredgewidth=None, +): + """ + Build size semantic legend handles and labels. + """ + values = np.asarray(list(levels), dtype=float) + if values.size == 0: + return [], [] + if area: + ms = np.sqrt(np.clip(values, 0, None)) + else: + ms = np.abs(values) + ms = np.maximum(ms * scale, minsize) + labels = [_format_label(value, fmt) for value in values] + handles = [ + LegendEntry.marker( + label=label, + marker=marker, + color=color, + markersize=float(size), + alpha=alpha, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + ) + for label, size in zip(labels, ms) + ] + return handles, labels + + +def _num_legend_entries( + levels=None, + *, + vmin=None, + vmax=None, + n: int = 5, + cmap="viridis", + norm=None, + fmt=None, + edgecolor="none", + linewidth: float = 0.0, + alpha=None, +): + """ + Build numeric-color semantic legend handles and labels. + """ + if levels is None: + if vmin is None or vmax is None: + raise ValueError("Please provide levels or both vmin and vmax.") + values = np.linspace(float(vmin), float(vmax), int(n)) + elif np.isscalar(levels) and isinstance(levels, (int, np.integer)): + if vmin is None or vmax is None: + raise ValueError("Please provide vmin and vmax when levels is an integer.") + values = np.linspace(float(vmin), float(vmax), int(levels)) + else: + values = np.asarray(list(levels), dtype=float) + if values.size == 0: + return [], [] + if norm is None: + lo = float(np.nanmin(values) if vmin is None else vmin) + hi = float(np.nanmax(values) if vmax is None else vmax) + norm = mcolors.Normalize(vmin=lo, vmax=hi) + try: + import matplotlib as mpl + + cmap_obj = mpl.colormaps.get_cmap(cmap) + except Exception: + cmap_obj = mcm.get_cmap(cmap) + labels = [_format_label(value, fmt) for value in values] + handles = [ + mpatches.Patch( + facecolor=cmap_obj(norm(float(value))), + edgecolor=edgecolor, + linewidth=linewidth, + alpha=alpha, + label=label, + ) + for value, label in zip(values, labels) + ] + return handles, labels + + ALIGN_OPTS = { None: { "center": "center", @@ -192,10 +1056,20 @@ def get_default_handler_map(cls): Extend matplotlib defaults with a wedge handler for pie legends. """ handler_map = dict(super().get_default_handler_map()) + handler_map.setdefault( + GeometryEntry, + _GeometryEntryLegendHandler(), + ) handler_map.setdefault( mpatches.Wedge, mhandler.HandlerPatch(patch_func=_wedge_legend_patch), ) + if _CartopyFeatureArtist is not None: + handler_map.setdefault(_CartopyFeatureArtist, _FeatureArtistLegendHandler()) + if _ShapelyBaseGeometry is not None: + handler_map.setdefault( + _ShapelyBaseGeometry, _ShapelyGeometryLegendHandler() + ) return handler_map @override @@ -241,6 +1115,211 @@ class UltraLegend: def __init__(self, axes): self.axes = axes + @staticmethod + def _validate_semantic_kwargs(method: str, kwargs: dict[str, Any]) -> None: + """ + Prevent ambiguous legend kwargs for semantic legend helpers. + """ + if "label" in kwargs: + raise TypeError( + f"{method}() does not accept the legend kwarg 'label'. " + "Use title=... for the legend title." + ) + if "labels" in kwargs: + raise TypeError( + f"{method}() does not accept the legend kwarg 'labels'. " + "Semantic legend labels are derived from the helper inputs." + ) + + def catlegend( + self, + categories: Iterable[Any], + *, + colors=None, + markers=None, + line: Optional[bool] = None, + linestyle=None, + linewidth: Optional[float] = None, + markersize: Optional[float] = None, + alpha=None, + markeredgecolor=None, + markeredgewidth=None, + add: bool = True, + **legend_kwargs: Any, + ): + """ + Build categorical legend entries and optionally draw a legend. + """ + line = _not_none(line, rc["legend.cat.line"]) + markers = _not_none(markers, rc["legend.cat.marker"]) + linestyle = _not_none(linestyle, rc["legend.cat.linestyle"]) + linewidth = _not_none(linewidth, rc["legend.cat.linewidth"]) + markersize = _not_none(markersize, rc["legend.cat.markersize"]) + alpha = _not_none(alpha, rc["legend.cat.alpha"]) + markeredgecolor = _not_none(markeredgecolor, rc["legend.cat.markeredgecolor"]) + markeredgewidth = _not_none(markeredgewidth, rc["legend.cat.markeredgewidth"]) + handles, labels = _cat_legend_entries( + categories, + colors=colors, + markers=markers, + line=line, + linestyle=linestyle, + linewidth=linewidth, + markersize=markersize, + alpha=alpha, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + ) + if not add: + return handles, labels + self._validate_semantic_kwargs("catlegend", legend_kwargs) + # Route through Axes.legend so location shorthands (e.g. 'r', 'b') + # and queued guide keyword handling behave exactly like the public API. + return self.axes.legend(handles, labels, **legend_kwargs) + + def sizelegend( + self, + levels: Iterable[float], + *, + color=None, + marker=None, + area: Optional[bool] = None, + scale: Optional[float] = None, + minsize: Optional[float] = None, + fmt=None, + alpha=None, + markeredgecolor=None, + markeredgewidth=None, + add: bool = True, + **legend_kwargs: Any, + ): + """ + Build size legend entries and optionally draw a legend. + """ + color = _not_none(color, rc["legend.size.color"]) + marker = _not_none(marker, rc["legend.size.marker"]) + area = _not_none(area, rc["legend.size.area"]) + scale = _not_none(scale, rc["legend.size.scale"]) + minsize = _not_none(minsize, rc["legend.size.minsize"]) + fmt = _not_none(fmt, rc["legend.size.format"]) + alpha = _not_none(alpha, rc["legend.size.alpha"]) + markeredgecolor = _not_none(markeredgecolor, rc["legend.size.markeredgecolor"]) + markeredgewidth = _not_none(markeredgewidth, rc["legend.size.markeredgewidth"]) + handles, labels = _size_legend_entries( + levels, + color=color, + marker=marker, + area=area, + scale=scale, + minsize=minsize, + fmt=fmt, + alpha=alpha, + markeredgecolor=markeredgecolor, + markeredgewidth=markeredgewidth, + ) + if not add: + return handles, labels + self._validate_semantic_kwargs("sizelegend", legend_kwargs) + return self.axes.legend(handles, labels, **legend_kwargs) + + def numlegend( + self, + levels=None, + *, + vmin=None, + vmax=None, + n: Optional[int] = None, + cmap=None, + norm=None, + fmt=None, + edgecolor=None, + linewidth: Optional[float] = None, + alpha=None, + add: bool = True, + **legend_kwargs: Any, + ): + """ + Build numeric-color legend entries and optionally draw a legend. + """ + n = _not_none(n, rc["legend.num.n"]) + cmap = _not_none(cmap, rc["legend.num.cmap"]) + edgecolor = _not_none(edgecolor, rc["legend.num.edgecolor"]) + linewidth = _not_none(linewidth, rc["legend.num.linewidth"]) + alpha = _not_none(alpha, rc["legend.num.alpha"]) + fmt = _not_none(fmt, rc["legend.num.format"]) + handles, labels = _num_legend_entries( + levels=levels, + vmin=vmin, + vmax=vmax, + n=n, + cmap=cmap, + norm=norm, + fmt=fmt, + edgecolor=edgecolor, + linewidth=linewidth, + alpha=alpha, + ) + if not add: + return handles, labels + self._validate_semantic_kwargs("numlegend", legend_kwargs) + return self.axes.legend(handles, labels, **legend_kwargs) + + def geolegend( + self, + entries: Iterable[Any] | dict[Any, Any], + labels: Optional[Iterable[Any]] = None, + *, + country_reso: Optional[str] = None, + country_territories: Optional[bool] = None, + country_proj: Any = None, + handlesize: Optional[float] = None, + facecolor: Any = None, + edgecolor: Any = None, + linewidth: Optional[float] = None, + alpha: Optional[float] = None, + fill: Optional[bool] = None, + add: bool = True, + **legend_kwargs: Any, + ): + """ + Build geometry legend entries and optionally draw a legend. + """ + facecolor = _not_none(facecolor, rc["legend.geo.facecolor"]) + edgecolor = _not_none(edgecolor, rc["legend.geo.edgecolor"]) + linewidth = _not_none(linewidth, rc["legend.geo.linewidth"]) + alpha = _not_none(alpha, rc["legend.geo.alpha"]) + fill = _not_none(fill, rc["legend.geo.fill"]) + country_reso = _not_none(country_reso, rc["legend.geo.country_reso"]) + country_territories = _not_none( + country_territories, rc["legend.geo.country_territories"] + ) + country_proj = _not_none(country_proj, rc["legend.geo.country_proj"]) + handlesize = _not_none(handlesize, rc["legend.geo.handlesize"]) + handles, labels = _geo_legend_entries( + entries, + labels=labels, + country_reso=country_reso, + country_territories=country_territories, + country_proj=country_proj, + facecolor=facecolor, + edgecolor=edgecolor, + linewidth=linewidth, + alpha=alpha, + fill=fill, + ) + if not add: + return handles, labels + self._validate_semantic_kwargs("geolegend", legend_kwargs) + if handlesize is not None: + handlesize = float(handlesize) + if handlesize <= 0: + raise ValueError("geolegend handlesize must be positive.") + if "handlelength" not in legend_kwargs: + legend_kwargs["handlelength"] = rc["legend.handlelength"] * handlesize + if "handleheight" not in legend_kwargs: + legend_kwargs["handleheight"] = rc["legend.handleheight"] * handlesize + return self.axes.legend(handles, labels, **legend_kwargs) + @staticmethod def _align_map() -> dict[Optional[str], dict[str, str]]: """ @@ -560,86 +1639,3 @@ def add( self._apply_handle_styles(objs, kw_text=kw_text, kw_handle=kw_handle) return self._finalize(objs, loc=inputs.loc, align=inputs.align) - - # Handle and text properties that are applied after-the-fact - # NOTE: Set solid_capstyle to 'butt' so line does not extend past error bounds - # shading in legend entry. This change is not noticable in other situations. - kw_frame, kwargs = lax._parse_frame("legend", **kwargs) - kw_text = {} - if fontcolor is not None: - kw_text["color"] = fontcolor - if fontweight is not None: - kw_text["weight"] = fontweight - kw_title = {} - if titlefontcolor is not None: - kw_title["color"] = titlefontcolor - if titlefontweight is not None: - kw_title["weight"] = titlefontweight - kw_handle = _pop_props(kwargs, "line") - kw_handle.setdefault("solid_capstyle", "butt") - kw_handle.update(handle_kw or {}) - - # Parse the legend arguments using axes for auto-handle detection - # TODO: Update this when we no longer use "filled panels" for outer legends - pairs, multi = lax._parse_legend_handles( - handles, - labels, - ncol=ncol, - order=order, - center=center, - alphabetize=alphabetize, - handler_map=handler_map, - ) - title = _not_none(label=label, title=title) - kwargs.update( - { - "title": title, - "frameon": frameon, - "fontsize": fontsize, - "handler_map": handler_map, - "title_fontsize": titlefontsize, - } - ) - - # Add the legend and update patch properties - # TODO: Add capacity for categorical labels in a single legend like seaborn - # rather than manual handle overrides with multiple legends. - if multi: - objs = lax._parse_legend_centered(pairs, kw_frame=kw_frame, **kwargs) - else: - kwargs.update({key: kw_frame.pop(key) for key in ("shadow", "fancybox")}) - objs = [lax._parse_legend_aligned(pairs, ncol=ncol, order=order, **kwargs)] - objs[0].legendPatch.update(kw_frame) - for obj in objs: - if hasattr(lax, "legend_") and lax.legend_ is None: - lax.legend_ = obj # make first legend accessible with get_legend() - else: - lax.add_artist(obj) - - # Update legend patch and elements - # WARNING: legendHandles only contains the *first* artist per legend because - # HandlerBase.legend_artist() called in Legend._init_legend_box() only - # returns the first artist. Instead we try to iterate through offset boxes. - for obj in objs: - obj.set_clip_on(False) # needed for tight bounding box calculations - box = getattr(obj, "_legend_handle_box", None) - for child in guides._iter_children(box): - if isinstance(child, mtext.Text): - kw = kw_text - else: - kw = { - key: val - for key, val in kw_handle.items() - if hasattr(child, "set_" + key) - } - if hasattr(child, "set_sizes") and "markersize" in kw_handle: - kw["sizes"] = np.atleast_1d(kw_handle["markersize"]) - child.update(kw) - - # Register location and return - if isinstance(objs[0], mpatches.FancyBboxPatch): - objs = objs[1:] - obj = objs[0] if len(objs) == 1 else tuple(objs) - ax._register_guide("legend", obj, (loc, align)) # possibly replace another - - return obj diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index e04287286..217417860 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest +from matplotlib import colors as mcolors from matplotlib import legend_handler as mhandler from matplotlib import patches as mpatches @@ -326,6 +327,371 @@ def test_legend_entry_with_axes_legend(): uplt.close(fig) +def test_semantic_helpers_not_public_on_module(): + for name in ("catlegend", "sizelegend", "numlegend", "geolegend"): + assert not hasattr(uplt, name) + + +def test_geo_legend_helper_shapes(): + fig, ax = uplt.subplots() + handles, labels = ax.geolegend( + [("Triangle", "triangle"), ("Hex", "hexagon")], add=False + ) + assert labels == ["Triangle", "Hex"] + assert len(handles) == 2 + assert all(isinstance(handle, mpatches.PathPatch) for handle in handles) + uplt.close(fig) + + +def test_semantic_legend_rc_defaults(): + fig, axs = uplt.subplots(ncols=4, share=False) + with uplt.rc.context( + { + "legend.cat.line": True, + "legend.cat.marker": "s", + "legend.cat.linewidth": 3.25, + "legend.size.marker": "^", + "legend.size.minsize": 8.0, + "legend.num.n": 3, + "legend.geo.facecolor": "red7", + "legend.geo.edgecolor": "black", + "legend.geo.fill": True, + } + ): + leg = axs[0].catlegend(["A"], loc="best") + h = leg.legend_handles[0] + assert h.get_marker() == "s" + assert h.get_linewidth() == pytest.approx(3.25) + + leg = axs[1].sizelegend([1.0], loc="best") + h = leg.legend_handles[0] + assert h.get_marker() == "^" + assert h.get_markersize() >= 8.0 + + leg = axs[2].numlegend(vmin=0, vmax=1, loc="best") + assert len(leg.legend_handles) == 3 + + leg = axs[3].geolegend([("shape", "triangle")], loc="best") + h = leg.legend_handles[0] + assert isinstance(h, mpatches.PathPatch) + assert np.allclose(h.get_facecolor(), mcolors.to_rgba("red7")) + uplt.close(fig) + + +def test_semantic_legend_loc_shorthand(): + fig, ax = uplt.subplots() + leg = ax.catlegend(["A", "B"], loc="r") + assert leg is not None + assert [text.get_text() for text in leg.get_texts()] == ["A", "B"] + uplt.close(fig) + + +@pytest.mark.parametrize( + "builder, args, kwargs", + ( + ("catlegend", (["A", "B"],), {}), + ("sizelegend", ([10, 50],), {}), + ("numlegend", tuple(), {"vmin": 0, "vmax": 1}), + ("geolegend", ([("shape", "triangle")],), {}), + ), +) +def test_semantic_legend_rejects_label_kwarg(builder, args, kwargs): + fig, ax = uplt.subplots() + method = getattr(ax, builder) + with pytest.raises(TypeError, match="Use title=\\.\\.\\. for the legend title"): + method(*args, label="Legend", **kwargs) + uplt.close(fig) + + +@pytest.mark.parametrize( + "builder, args, kwargs", + ( + ("catlegend", (["A", "B"],), {}), + ("sizelegend", ([10, 50],), {}), + ("numlegend", tuple(), {"vmin": 0, "vmax": 1}), + ), +) +def test_semantic_legend_rejects_labels_kwarg(builder, args, kwargs): + fig, ax = uplt.subplots() + method = getattr(ax, builder) + with pytest.raises(TypeError, match="does not accept the legend kwarg 'labels'"): + method(*args, labels=["x", "y"], **kwargs) + uplt.close(fig) + + +def test_geo_legend_handlesize_scales_handle_box(): + fig, ax = uplt.subplots() + leg = ax.geolegend([("shape", "triangle")], loc="best", handlesize=2.0) + assert leg.handlelength == pytest.approx(2.0 * uplt.rc["legend.handlelength"]) + assert leg.handleheight == pytest.approx(2.0 * uplt.rc["legend.handleheight"]) + + with uplt.rc.context({"legend.geo.handlesize": 1.5}): + leg = ax.geolegend([("shape", "triangle")], loc="best") + assert leg.handlelength == pytest.approx(1.5 * uplt.rc["legend.handlelength"]) + assert leg.handleheight == pytest.approx(1.5 * uplt.rc["legend.handleheight"]) + uplt.close(fig) + + +def test_geo_legend_helper_with_axes_legend(monkeypatch): + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + monkeypatch.setattr( + plegend, + "_resolve_country_geometry", + lambda _, resolution="110m", include_far=False: sgeom.box(-1, -1, 1, 1), + ) + fig, ax = uplt.subplots() + leg = ax.geolegend({"AUS": "country:AU", "NZL": "country:NZ"}, loc="best") + assert [text.get_text() for text in leg.get_texts()] == ["AUS", "NZL"] + uplt.close(fig) + + +def test_geo_legend_country_resolution_passthrough(monkeypatch): + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + calls = [] + + def _fake_country(code, resolution="110m", include_far=False): + calls.append((str(code).upper(), resolution, bool(include_far))) + return sgeom.box(-1, -1, 1, 1) + + monkeypatch.setattr(plegend, "_resolve_country_geometry", _fake_country) + + fig, ax = uplt.subplots() + ax.geolegend([("NLD", "country:NLD")], country_reso="10m", add=False) + assert calls == [("NLD", "10m", False)] + + calls.clear() + with uplt.rc.context({"legend.geo.country_reso": "50m"}): + ax.geolegend([("NLD", "country:NLD")], add=False) + assert calls == [("NLD", "50m", False)] + + calls.clear() + ax.geolegend([("NLD", "country:NLD")], country_territories=True, add=False) + assert calls == [("NLD", "110m", True)] + + calls.clear() + with uplt.rc.context({"legend.geo.country_territories": True}): + ax.geolegend([("NLD", "country:NLD")], add=False) + assert calls == [("NLD", "110m", True)] + uplt.close(fig) + + +def test_geo_legend_country_projection_passthrough(monkeypatch): + sgeom = pytest.importorskip("shapely.geometry") + from shapely import affinity + from ultraplot import legend as plegend + + monkeypatch.setattr( + plegend, + "_resolve_country_geometry", + lambda code, resolution="110m", include_far=False: sgeom.box(0, 0, 2, 1), + ) + fig, ax = uplt.subplots() + handles0, _ = ax.geolegend([("NLD", "country:NLD")], add=False) + handles1, _ = ax.geolegend( + [("NLD", "country:NLD")], + country_proj=lambda geom: affinity.scale( + geom, xfact=2.0, yfact=1.0, origin=(0, 0) + ), + add=False, + ) + w0 = np.ptp(handles0[0].get_path().vertices[:, 0]) + w1 = np.ptp(handles1[0].get_path().vertices[:, 0]) + assert w1 > w0 + + handles2, _ = ax.geolegend( + [("NLD", "country:NLD")], + add=False, + country_proj="platecarree", + ) + assert isinstance(handles2[0], mpatches.PathPatch) + + # Per-entry overrides via 3-tuples + handles3, labels3 = ax.geolegend( + [ + ("Base", "country:NLD"), + ( + "Wide", + "country:NLD", + { + "country_proj": lambda geom: affinity.scale( + geom, xfact=2.0, yfact=1.0, origin=(0, 0) + ) + }, + ), + ("StringProj", "country:NLD", "platecarree"), + ], + add=False, + ) + assert labels3 == ["Base", "Wide", "StringProj"] + w_base = np.ptp(handles3[0].get_path().vertices[:, 0]) + w_wide = np.ptp(handles3[1].get_path().vertices[:, 0]) + assert w_wide > w_base + uplt.close(fig) + + +def test_country_geometry_uses_dominant_component(): + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + big = sgeom.box(4.0, 51.0, 7.0, 54.0) + tiny_far = sgeom.box(-69.0, 12.0, -68.8, 12.2) + geometry = sgeom.MultiPolygon([big, tiny_far]) + dominant = plegend._country_geometry_for_legend(geometry) + assert dominant.equals(big) + + +def test_country_geometry_keeps_nearby_islands(): + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + mainland = sgeom.box(4.0, 51.0, 7.0, 54.0) + nearby_island = sgeom.box(5.0, 54.2, 5.2, 54.35) + far_island = sgeom.box(-69.0, 12.0, -68.8, 12.2) + geometry = sgeom.MultiPolygon([mainland, nearby_island, far_island]) + + reduced = plegend._country_geometry_for_legend(geometry) + geoms = list(getattr(reduced, "geoms", [reduced])) + assert any(part.equals(mainland) for part in geoms) + assert any(part.equals(nearby_island) for part in geoms) + assert not any(part.equals(far_island) for part in geoms) + + +def test_country_geometry_can_include_far_territories(): + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + mainland = sgeom.box(4.0, 51.0, 7.0, 54.0) + far_island = sgeom.box(-69.0, 12.0, -68.8, 12.2) + geometry = sgeom.MultiPolygon([mainland, far_island]) + kept = plegend._country_geometry_for_legend(geometry, include_far=True) + geoms = list(getattr(kept, "geoms", [kept])) + assert any(part.equals(mainland) for part in geoms) + assert any(part.equals(far_island) for part in geoms) + + +def test_geo_axes_add_geometries_auto_legend(): + ccrs = pytest.importorskip("cartopy.crs") + sgeom = pytest.importorskip("shapely.geometry") + + fig, ax = uplt.subplots(proj="cyl") + ax.add_geometries( + [sgeom.box(-20, -10, 20, 10)], + ccrs.PlateCarree(), + facecolor="blue7", + edgecolor="blue9", + label="Region", + ) + leg = ax.legend(loc="best") + labels = [text.get_text() for text in leg.get_texts()] + assert "Region" in labels + assert len(leg.legend_handles) == 1 + assert isinstance(leg.legend_handles[0], mpatches.PathPatch) + assert leg.legend_handles[0].get_joinstyle() == "bevel" + uplt.close(fig) + + +def test_geo_legend_defaults_to_bevel_joinstyle(): + fig, ax = uplt.subplots() + leg = ax.geolegend([("shape", "triangle")], loc="best") + assert isinstance(leg.legend_handles[0], mpatches.PathPatch) + assert leg.legend_handles[0].get_joinstyle() == "bevel" + uplt.close(fig) + + +def test_geo_legend_joinstyle_override(): + fig, ax = uplt.subplots() + leg = ax.geolegend([("shape", "triangle", {"joinstyle": "round"})], loc="best") + assert leg.legend_handles[0].get_joinstyle() == "round" + uplt.close(fig) + + +@pytest.mark.mpl_image_compare +def test_semantic_legends_showcase_smoke(monkeypatch): + """ + End-to-end smoke test showing semantic legend helpers in one figure: + categorical, size, numeric-color, and geometry (generic + country shorthands). + """ + sgeom = pytest.importorskip("shapely.geometry") + from ultraplot import legend as plegend + + # Prefer real Natural Earth country geometries if available. In offline CI, + # fall back to deterministic local geometries while still exercising shorthand. + country_entries = [("Australia", "country:AU"), ("New Zealand", "country:NZ")] + uses_real_countries = True + try: + fig_tmp, ax_tmp = uplt.subplots() + ax_tmp.geolegend( + country_entries, edgecolor="black", facecolor="none", add=False + ) + uplt.close(fig_tmp) + except ValueError: + uses_real_countries = False + country_geoms = { + "AU": sgeom.box(110, -45, 155, -10), + "NZ": sgeom.box(166, -48, 179, -34), + } + + def _fake_country(code): + key = str(code).upper() + if key not in country_geoms: + raise ValueError(f"Unknown shorthand in test: {code!r}") + return country_geoms[key] + + monkeypatch.setattr(plegend, "_resolve_country_geometry", _fake_country) + + fig, axs = uplt.subplots(ncols=2, nrows=2, share=False) + + leg = axs[0].catlegend( + ["A", "B", "C"], + colors={"A": "red7", "B": "green7", "C": "blue7"}, + markers={"A": "o", "B": "s", "C": "^"}, + loc="best", + title="catlegend", + ) + assert [text.get_text() for text in leg.get_texts()] == ["A", "B", "C"] + + leg = axs[1].sizelegend( + [10, 50, 200], color="gray6", loc="best", title="sizelegend" + ) + assert [text.get_text() for text in leg.get_texts()] == ["10", "50", "200"] + + leg = axs[2].numlegend( + vmin=0.0, + vmax=1.0, + n=4, + cmap="viridis", + fmt="{:.2f}", + loc="best", + title="numlegend", + ) + assert len(leg.legend_handles) == 4 + assert all(isinstance(handle, mpatches.Patch) for handle in leg.legend_handles) + + handles, labels = axs[3].geolegend( + [ + ("Triangle", "triangle"), + ("Hexagon", "hexagon"), + *country_entries, + ], + edgecolor="black", + facecolor="none", + add=False, + ) + leg = axs[3].legend(handles, labels, loc="best", title="geolegend") + legend_labels = [text.get_text() for text in leg.get_texts()] + assert set(legend_labels) == set(labels) + assert len(legend_labels) == len(labels) + assert all(isinstance(handle, mpatches.PathPatch) for handle in leg.legend_handles) + if uses_real_countries: + # Real shorthand resolution succeeded (no monkeypatched fallback). + assert {"Australia", "New Zealand"}.issubset(set(legend_labels)) + return fig + + def test_pie_legend_uses_wedge_handles(): fig, ax = uplt.subplots() wedges, _ = ax.pie([30, 70], labels=["a", "b"])