diff --git a/src/spatialdata_plot/pl/render.py b/src/spatialdata_plot/pl/render.py index 52c8b773..6ada5397 100644 --- a/src/spatialdata_plot/pl/render.py +++ b/src/spatialdata_plot/pl/render.py @@ -9,6 +9,7 @@ import geopandas as gpd import matplotlib import matplotlib.pyplot as plt +import matplotlib.ticker import numpy as np import pandas as pd import scanpy as sc @@ -141,6 +142,15 @@ def _render_shapes( color_source_vector = color_source_vector[mask] color_vector = color_vector[mask] + # continuous case: leave NaNs as NaNs; utils maps them to na_color during draw + if color_source_vector is None and not values_are_categorical: + color_vector = np.asarray(color_vector, dtype=float) + if np.isnan(color_vector).any(): + nan_count = int(np.isnan(color_vector).sum()) + logger.warning( + f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'." + ) + # Using dict.fromkeys here since set returns in arbitrary order # remove the color of NaN values, else it might be assigned to a category # order of color in the palette should agree to order of occurence @@ -195,7 +205,10 @@ def _render_shapes( # Handle circles encoded as points with radius if is_point.any(): - scale = shapes[is_point]["radius"] * render_params.scale + radius_values = shapes[is_point]["radius"] + # Convert to numeric, replacing non-numeric values with NaN + radius_numeric = pd.to_numeric(radius_values, errors="coerce") + scale = radius_numeric * render_params.scale shapes.loc[is_point, "geometry"] = _geometry[is_point].buffer(scale.to_numpy()) # apply transformations to the individual points @@ -218,6 +231,20 @@ def _render_shapes( # in case we are coloring by a column in table if col_for_color is not None and col_for_color not in transformed_element.columns: + # Ensure color vector length matches the number of shapes + if len(color_vector) != len(transformed_element): + if len(color_vector) == 1: + # If single color, broadcast to all shapes + color_vector = 
[color_vector[0]] * len(transformed_element) + else: + # If lengths don't match, pad or truncate to match + if len(color_vector) > len(transformed_element): + color_vector = color_vector[: len(transformed_element)] + else: + # Pad with the last color or na_color + na_color = render_params.cmap_params.na_color.get_hex_with_alpha() + color_vector = list(color_vector) + [na_color] * (len(transformed_element) - len(color_vector)) + transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector # Render shapes with datashader color_by_categorical = col_for_color is not None and color_source_vector is not None @@ -447,12 +474,13 @@ def _render_shapes( path.vertices = trans.transform(path.vertices) if not values_are_categorical: - # If the user passed a Normalize object with vmin/vmax we'll use those, - # if not we'll use the min/max of the color_vector - _cax.set_clim( - vmin=render_params.cmap_params.norm.vmin or min(color_vector), - vmax=render_params.cmap_params.norm.vmax or max(color_vector), - ) + vmin = render_params.cmap_params.norm.vmin + vmax = render_params.cmap_params.norm.vmax + if vmin is None: + vmin = float(np.nanmin(color_vector)) + if vmax is None: + vmax = float(np.nanmax(color_vector)) + _cax.set_clim(vmin=vmin, vmax=vmax) if ( len(set(color_vector)) != 1 diff --git a/src/spatialdata_plot/pl/utils.py b/src/spatialdata_plot/pl/utils.py index 82b84973..9ef5bcfd 100644 --- a/src/spatialdata_plot/pl/utils.py +++ b/src/spatialdata_plot/pl/utils.py @@ -17,6 +17,7 @@ import matplotlib.patches as mpatches import matplotlib.path as mpath import matplotlib.pyplot as plt +import matplotlib.ticker import matplotlib.transforms as mtransforms import numpy as np import numpy.ma as ma @@ -24,7 +25,6 @@ import pandas as pd import shapely import spatialdata as sd -import xarray as xr from anndata import AnnData from cycler import Cycler, cycler from datashader.core import Canvas @@ -94,6 +94,50 @@ ColorLike = tuple[float, ...] 
| list[float] | str +def _extract_scalar_value(value: Any, default: float = 0.0) -> float: + """ + Extract a scalar float value from various data types. + + Handles pandas Series, arrays, lists, and other iterables by taking the first element. + Converts non-numeric values to the default value. + + Parameters + ---------- + value : Any + The value to extract a scalar from + default : float, default 0.0 + Default value to return if conversion fails + + Returns + ------- + float + The extracted scalar value + """ + try: + # Handle pandas Series or similar objects with iloc + if hasattr(value, "iloc"): + if len(value) > 0: + value = value.iloc[0] + else: + return default + + # Handle other array-like objects + elif hasattr(value, "__len__") and not isinstance(value, (str, bytes)): + if len(value) > 0: + value = value[0] + else: + return default + + # Convert to float, handling NaN values + if pd.isna(value): + return default + + return float(value) + + except (TypeError, ValueError, IndexError): + return default + + def _verify_plotting_tree(sdata: SpatialData) -> SpatialData: """Verify that the plotting tree exists, and if not, create it.""" if not hasattr(sdata, "plotting_tree"): @@ -286,9 +330,10 @@ def _get_centroid_of_pathpatch(pathpatch: mpatches.PathPatch) -> tuple[float, fl def _scale_pathpatch_around_centroid(pathpatch: mpatches.PathPatch, scale_factor: float) -> None: + scale_value = _extract_scalar_value(scale_factor, default=1.0) centroid = _get_centroid_of_pathpatch(pathpatch) vertices = pathpatch.get_path().vertices - scaled_vertices = np.array([centroid + (vertex - centroid) * scale_factor for vertex in vertices]) + scaled_vertices = np.array([centroid + (vertex - centroid) * scale_value for vertex in vertices]) pathpatch.get_path().vertices = scaled_vertices @@ -305,140 +350,173 @@ def _get_collection_shape( **kwargs: Any, ) -> PatchCollection: """ - Get a PatchCollection for rendering given geometries with specified colors and outlines. 
- - Args: - - shapes (list[GeoDataFrame]): List of geometrical shapes. - - c: Color parameter. - - s (float): Scale of the shape. - - norm: Normalization for the color map. - - fill_alpha (float, optional): Opacity for the fill color. - - outline_alpha (float, optional): Opacity for the outline. - - outline_color (optional): Color for the outline. - - linewidth (float, optional): Width for the outline. - - **kwargs: Additional keyword arguments. + Build a PatchCollection for shapes with correct handling of. - Returns - ------- - - PatchCollection: Collection of patches for rendering. + - continuous numeric vectors with NaNs, + - per-row RGBA arrays, + - a single color or a list of color specs. + + Only NaNs are painted with na_color; finite values are mapped via norm+cmap. """ cmap = kwargs["cmap"] - try: - # fails when numeric - if len(c.shape) == 1 and c.shape[0] in [3, 4] and c.shape[0] == len(shapes) and c.dtype == float: - if norm is None: - c = cmap(c) - else: - try: - norm = colors.Normalize(vmin=min(c), vmax=max(c)) if norm is None else norm - except ValueError as e: - raise ValueError( - "Could not convert values in the `color` column to float, if `color` column represents" - " categories, set the column to categorical dtype." 
- ) from e - c = cmap(norm(c)) - else: - fill_c = ColorConverter().to_rgba_array(c) - except ValueError: - if norm is None: - c = cmap(c) + # Resolve na color once + na_rgba = colors.to_rgba(render_params.cmap_params.na_color.get_hex_with_alpha()) + + # Try to interpret c as numpy array + c_arr = np.asarray(c) + fill_c: np.ndarray + + def _as_rgba_array(x: Any) -> np.ndarray: + return np.asarray(ColorConverter().to_rgba_array(x)) + + # Case A: per-row numeric colors given as Nx3 or Nx4 float array + if ( + c_arr.ndim == 2 + and c_arr.shape[0] == len(shapes) + and c_arr.shape[1] in (3, 4) + and np.issubdtype(c_arr.dtype, np.number) + ): + fill_c = _as_rgba_array(c_arr) + + # Case B: continuous numeric vector len == n_shapes (possibly with NaNs) + elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and np.issubdtype(c_arr.dtype, np.number): + finite_mask = np.isfinite(c_arr) + + # Select or build a normalization that ignores NaNs for scaling + if isinstance(norm, Normalize): + used_norm: Normalize = norm else: - try: - norm = colors.Normalize(vmin=min(c), vmax=max(c)) if norm is None else norm - except ValueError as e: - raise ValueError( - "Could not convert values in the `color` column to float, if `color` column represents" - " categories, set the column to categorical dtype." 
- ) from e - c = cmap(norm(c)) + if finite_mask.any(): + vmin = float(np.nanmin(c_arr[finite_mask])) + vmax = float(np.nanmax(c_arr[finite_mask])) + if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax: + vmin, vmax = 0.0, 1.0 + else: + vmin, vmax = 0.0, 1.0 + used_norm = colors.Normalize(vmin=vmin, vmax=vmax, clip=False) + + # Map finite values through cmap(norm(.)); NaNs get na_color + fill_c = np.empty((len(c_arr), 4), dtype=float) + fill_c[:] = na_rgba + if finite_mask.any(): + fill_c[finite_mask] = cmap(used_norm(c_arr[finite_mask])) + + elif c_arr.ndim == 1 and len(c_arr) == len(shapes) and c_arr.dtype == object: + # Split into numeric vs color-like + c_series = pd.Series(c_arr, copy=False) + num = pd.to_numeric(c_series, errors="coerce").to_numpy() + is_num = np.isfinite(num) + + # init with na color + fill_c = np.empty((len(c_series), 4), dtype=float) + fill_c[:] = na_rgba + + # numeric entries via cmap(norm) + if is_num.any(): + if isinstance(norm, Normalize): + used_norm = norm + else: + vmin = float(np.nanmin(num[is_num])) if is_num.any() else 0.0 + vmax = float(np.nanmax(num[is_num])) if is_num.any() else 1.0 + if not np.isfinite(vmin) or not np.isfinite(vmax) or vmin == vmax: + vmin, vmax = 0.0, 1.0 + used_norm = colors.Normalize(vmin=vmin, vmax=vmax, clip=False) + fill_c[is_num] = cmap(used_norm(num[is_num])) + + # non-numeric entries as explicit colors + if (~is_num).any(): + fill_c[~is_num] = ColorConverter().to_rgba_array(c_series[~is_num].tolist()) + + # Case C: single color or list of color-like specs (strings or tuples) + else: + fill_c = _as_rgba_array(c) - fill_c = ColorConverter().to_rgba_array(c) - # fill_c[..., -1] *= fill_alpha # NOTE: this contradicts matplotlib behavior, therefore discarded + # Apply optional fill alpha without destroying existing transparency if fill_alpha is not None: - fill_c[..., -1] = fill_alpha + nonzero_alpha = fill_c[..., -1] > 0 + fill_c[nonzero_alpha, -1] = fill_alpha + # Outline handling if 
outline_alpha and outline_alpha > 0.0: - outline_c = ColorConverter().to_rgba_array(outline_color) - outline_c[..., -1] = outline_alpha - outline_c = outline_c.tolist() + outline_c_array = _as_rgba_array(outline_color) + outline_c_array[..., -1] = outline_alpha + outline_c = outline_c_array.tolist() else: - outline_c = [None] - outline_c = outline_c * fill_c.shape[0] + outline_c = [None] * fill_c.shape[0] + # Build DataFrame of valid geometries shapes_df = pd.DataFrame(shapes, copy=True) shapes_df = shapes_df[shapes_df["geometry"].apply(lambda geom: not geom.is_empty)] shapes_df = shapes_df.reset_index(drop=True) def _assign_fill_and_outline_to_row( - fill_c: list[Any], - outline_c: list[Any], + fill_colors: list[Any], + outline_colors: list[Any], row: dict[str, Any], idx: int, is_multiple_shapes: bool, ) -> None: - try: - if is_multiple_shapes and len(fill_c) == 1: - row["fill_c"] = fill_c[0] - row["outline_c"] = outline_c[0] - else: - row["fill_c"] = fill_c[idx] - row["outline_c"] = outline_c[idx] - except IndexError as e: - raise IndexError("Could not assign fill and outline colors due to a mismatch in row numbers.") from e + if is_multiple_shapes and len(fill_colors) == 1: + row["fill_c"] = fill_colors[0] + row["outline_c"] = outline_colors[0] + else: + row["fill_c"] = fill_colors[idx] + row["outline_c"] = outline_colors[idx] - def _process_polygon(row: pd.Series, s: float) -> dict[str, Any]: + def _process_polygon(row: pd.Series, scale: float) -> dict[str, Any]: coords = np.array(row["geometry"].exterior.coords) centroid = np.mean(coords, axis=0) - scaled_vectors = (coords - centroid) * s - scaled_coords = (centroid + scaled_vectors).tolist() - return { - **row.to_dict(), - "geometry": mpatches.Polygon(scaled_coords, closed=True), - } + scale_value = _extract_scalar_value(scale, default=1.0) + scaled = (centroid + (coords - centroid) * scale_value).tolist() + return {**row.to_dict(), "geometry": mpatches.Polygon(scaled, closed=True)} - def 
_process_multipolygon(row: pd.Series, s: float) -> list[dict[str, Any]]: + def _process_multipolygon(row: pd.Series, scale: float) -> list[dict[str, Any]]: mp = _make_patch_from_multipolygon(row["geometry"]) row_dict = row.to_dict() for m in mp: - _scale_pathpatch_around_centroid(m, s) - + _scale_pathpatch_around_centroid(m, scale) return [{**row_dict, "geometry": m} for m in mp] - def _process_point(row: pd.Series, s: float) -> dict[str, Any]: + def _process_point(row: pd.Series, scale: float) -> dict[str, Any]: + radius_value = _extract_scalar_value(row["radius"], default=0.0) + scale_value = _extract_scalar_value(scale, default=1.0) + radius = radius_value * scale_value + return { **row.to_dict(), - "geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=row["radius"] * s), + "geometry": mpatches.Circle((row["geometry"].x, row["geometry"].y), radius=radius), } - def _create_patches(shapes_df: GeoDataFrame, fill_c: list[Any], outline_c: list[Any], s: float) -> pd.DataFrame: - rows = [] - is_multiple_shapes = len(shapes_df) > 1 - - for idx, row in shapes_df.iterrows(): + def _create_patches( + shapes_df_: GeoDataFrame, fill_colors: list[Any], outline_colors: list[Any], scale: float + ) -> pd.DataFrame: + rows: list[dict[str, Any]] = [] + is_multiple = len(shapes_df_) > 1 + for idx, row in shapes_df_.iterrows(): geom_type = row["geometry"].geom_type - processed_rows = [] - + processed: list[dict[str, Any]] = [] if geom_type == "Polygon": - processed_rows.append(_process_polygon(row, s)) + processed.append(_process_polygon(row, scale)) elif geom_type == "MultiPolygon": - processed_rows.extend(_process_multipolygon(row, s)) + processed.extend(_process_multipolygon(row, scale)) elif geom_type == "Point": - processed_rows.append(_process_point(row, s)) - - for processed_row in processed_rows: - _assign_fill_and_outline_to_row(fill_c, outline_c, processed_row, idx, is_multiple_shapes) - rows.append(processed_row) - + 
processed.append(_process_point(row, scale)) + for pr in processed: + _assign_fill_and_outline_to_row(fill_colors, outline_colors, pr, idx, is_multiple) + rows.append(pr) return pd.DataFrame(rows) - patches = _create_patches(shapes_df, fill_c, outline_c, s) + patches = _create_patches( + shapes_df, fill_c.tolist(), outline_c.tolist() if hasattr(outline_c, "tolist") else outline_c, s + ) + return PatchCollection( patches["geometry"].values.tolist(), snap=False, lw=linewidth, facecolor=patches["fill_c"], - edgecolor=None if all(outline is None for outline in outline_c) else outline_c, + edgecolor=None if all(o is None for o in outline_c) else outline_c, **kwargs, ) @@ -651,57 +729,6 @@ def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int = return fig, axes -def _normalize( - img: xr.DataArray, - pmin: float | None = None, - pmax: float | None = None, - eps: float = 1e-20, - clip: bool = False, - name: str = "normed", -) -> xr.DataArray: - """Perform a min max normalisation on the xr.DataArray. - - This function was adapted from the csbdeep package. - - Parameters - ---------- - dataarray - A xarray DataArray with an image field. - pmin - Lower quantile (min value) used to perform quantile normalization. - pmax - Upper quantile (max value) used to perform quantile normalization. - eps - Epsilon float added to prevent 0 division. - clip - Ensures that normed image array contains no values greater than 1. - - Returns - ------- - xr.DataArray - A min-max normalized image. - """ - pmin = pmin or 0.0 - pmax = pmax or 100.0 - - perc = np.percentile(img, [pmin, pmax]) - - # Ensure perc is an array of two elements - if np.isscalar(perc): - logger.warning( - "Percentile range is too small, using the same percentile for both min " - "and max. Consider using a larger percentile range." 
- ) - perc = np.array([perc, perc]) - - norm = (img - perc[0]) / (perc[1] - perc[0] + eps) # type: ignore - - if clip: - norm = np.clip(norm, 0, 1) - - return norm - - def _get_colors_for_categorical_obs( categories: Sequence[str | int], palette: ListedColormap | str | list[str] | None = None, @@ -1220,13 +1247,6 @@ def _get_linear_colormap(colors: list[str], background: str) -> list[LinearSegme return [LinearSegmentedColormap.from_list(c, [background, c], N=256) for c in colors] -def _get_listed_colormap(color_dict: dict[str, str]) -> ListedColormap: - sorted_labels = sorted(color_dict.keys()) - colors = [color_dict[k] for k in sorted_labels] - - return ListedColormap(["black"] + colors, N=len(colors) + 1) - - def _split_multipolygon_into_outer_and_inner(mp: shapely.MultiPolygon): # type: ignore # https://stackoverflow.com/a/21922058 @@ -1802,9 +1822,7 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st if (norm := param_dict.get("norm")) is not None: if element_type in {"images", "labels"} and not isinstance(norm, Normalize): raise TypeError("Parameter 'norm' must be of type Normalize.") - if element_type in {"shapes", "points"} and not isinstance( - norm, bool | Normalize - ): + if element_type in {"shapes", "points"} and not isinstance(norm, bool | Normalize): raise TypeError("Parameter 'norm' must be a boolean or a mpl.Normalize.") if (scale := param_dict.get("scale")) is not None: @@ -1823,15 +1841,11 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st raise ValueError("Parameter 'size' must be a positive number.") if element_type == "shapes" and (shape := param_dict.get("shape")) is not None: - valid_shapes = {"circle", "hex", "visium_hex", "square"} + valid_shapes = {"circle", "hex", "visium_hex", "square"} if not isinstance(shape, str): - raise TypeError( - f"Parameter 'shape' must be a String from {valid_shapes} if not None." 
def _convert_shapes(
    shapes: GeoDataFrame,
    target_shape: str,
    max_extent: float,
    warn_above_extent_fraction: float = 0.5,
) -> GeoDataFrame:
    """Convert every geometry in ``shapes`` to ``target_shape``.

    Each geometry is first reduced to a circle: a Point keeps its own centre and
    ``radius``; a Polygon/MultiPolygon uses the mean of its convex-hull vertices as
    centre and the largest hull-vertex distance from that centre as radius. The
    circle is then emitted as a circle, a hexagon or a square.

    Parameters
    ----------
    shapes
        Geometries to convert. The input is left untouched; a copy with a fresh
        positional index is converted and returned.
    target_shape
        ``"circle"``, ``"hex"`` or ``"visium_hex"``. Any other value produces squares.
        For ``"visium_hex"`` all Points share one hexagon radius derived from the
        smallest distance between two Point centres; if fewer than two Points exist
        or that distance is not positive and finite, it behaves like ``"hex"``.
    max_extent
        Reference extent against which converted polygon sizes are compared.
    warn_above_extent_fraction
        A converted Polygon/MultiPolygon whose diameter exceeds
        ``max_extent * warn_above_extent_fraction`` triggers an info log message.
        Values outside ``[0, 1]`` are replaced by ``0.5``.

    Returns
    -------
    GeoDataFrame
        The converted copy. It always carries a ``radius`` column (created with NaN
        when absent); the column is only written for ``target_shape="circle"``.

    Raises
    ------
    ValueError
        If a geometry is not a Point, Polygon or MultiPolygon.
    """
    if warn_above_extent_fraction < 0.0 or warn_above_extent_fraction > 1.0:
        warn_above_extent_fraction = 0.5

    warn_shape_size = False

    # work on a copy with a clean positional index so ``.at[i, ...]`` is positional
    shapes = shapes.reset_index(drop=True).copy()
    # snapshot the inputs: the frame is written to while we iterate
    geometries = list(shapes.geometry)

    def _regular_polygon(center: shapely.Point, radius: float, n_vertices: int) -> shapely.Polygon:
        """Return the regular ``n_vertices``-gon inscribed in the given circle."""
        step = 360 // n_vertices
        # half-step offset: hexagons start at 30 deg (vertex on top), squares at 45 deg (axis-aligned)
        start = step // 2
        return shapely.Polygon(
            [
                (
                    center.x + radius * math.cos(math.radians(angle)),
                    center.y + radius * math.sin(math.radians(angle)),
                )
                for angle in range(start, start + 360, step)
            ]
        )

    def _enclosing_circle(geom: shapely.Polygon | shapely.MultiPolygon) -> tuple[shapely.Point, float]:
        """Return centre and radius of the circle derived from the geometry's convex hull."""
        nonlocal warn_shape_size
        if geom.geom_type == "Polygon":
            coords = np.array(geom.exterior.coords)
        else:
            coords = np.array([xy for poly in geom.geoms for xy in poly.exterior.coords])
        hull_pts = coords[ConvexHull(coords).vertices]
        center = np.mean(hull_pts, axis=0)
        radius = float(np.max(np.linalg.norm(hull_pts - center, axis=1)))
        if 2 * radius > max_extent * warn_above_extent_fraction:
            warn_shape_size = True
        return shapely.Point(center), radius

    # For visium_hex, Points get a common radius estimated from the point spacing.
    visium_radius: float | None = None
    if target_shape == "visium_hex":
        point_centers = [(g.x, g.y) for g in geometries if g.geom_type == "Point"]
        if len(point_centers) < len(geometries):
            warnings.warn(
                "visium_hex supports Points best. Non-Point geometries will use regular hex conversion.",
                UserWarning,
                stacklevel=2,
            )
        if len(point_centers) >= 2:
            # local import: only this branch needs a nearest-neighbour search
            from scipy.spatial import KDTree

            centers = np.asarray(point_centers, dtype=float)
            # k=2: column 0 is each point itself (distance 0), column 1 its nearest neighbour
            neighbour_dist, _ = KDTree(centers).query(centers, k=2)
            dmin = float(np.min(neighbour_dist[:, 1]))
            if np.isfinite(dmin) and dmin > 0:
                # vertex-on-top hexagons with centre spacing d tile when circumradius = d / sqrt(3)
                visium_radius = dmin / math.sqrt(3.0)

    keep_circles = target_shape == "circle"
    # anything that is neither circle nor a hex variant falls back to squares
    n_vertices = 6 if target_shape in {"hex", "visium_hex"} else 4

    if "radius" not in shapes.columns:
        shapes["radius"] = np.nan
    # Coerce once: non-numeric, missing and infinite radii all become 0.0 instead of
    # raising inside np.isfinite on the first non-float entry.
    radii = pd.to_numeric(shapes["radius"], errors="coerce").to_numpy(dtype=float, na_value=np.nan)
    radii = np.where(np.isfinite(radii), radii, 0.0)

    for i, geom in enumerate(geometries):
        gtype = geom.geom_type
        if gtype == "Point":
            center, radius = geom, float(radii[i])
            polygon_radius = radius if visium_radius is None else visium_radius
        elif gtype in ("Polygon", "MultiPolygon"):
            center, radius = _enclosing_circle(geom)
            polygon_radius = radius
        else:
            raise ValueError(f"Converting shape {gtype} to {target_shape} is not supported.")

        if keep_circles:
            shapes.at[i, "geometry"] = center
            shapes.at[i, "radius"] = radius
        else:
            shapes.at[i, "geometry"] = _regular_polygon(center, polygon_radius, n_vertices)

    if warn_shape_size:
        logger.info(
            f"At least one converted shape spans >= {warn_above_extent_fraction * 100:.0f}% of the "
            "original total bound. Results may be suboptimal."
        )

    return shapes
sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_polygons" + + sdata_blobs.shapes["blobs_polygons"]["color_with_nan"] = [1.0, 2.0, np.nan, 4.0, 5.0] + + # Test colorbar with NaN values - should use nanmin/nanmax + sdata_blobs.pl.render_shapes(element="blobs_polygons", color="color_with_nan", na_color="gray").pl.show() + + def test_plot_can_handle_non_numeric_radius_values(self, sdata_blobs: SpatialData): + """Test that non-numeric radius values are handled gracefully.""" + sdata_blobs.shapes["blobs_circles"]["radius_mixed"] = [1.0, "invalid", 3.0, np.nan, 5.0] + + sdata_blobs.pl.render_shapes(element="blobs_circles", color="red").pl.show() + + def test_plot_can_handle_mixed_numeric_and_color_data(self, sdata_blobs: SpatialData): + """Test handling of mixed numeric and color-like data.""" + sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs) + sdata_blobs["table"].uns["spatialdata_attrs"]["region"] = "blobs_circles" + + sdata_blobs.shapes["blobs_circles"]["mixed_data"] = [1.0, 2.0, np.nan, "red", 5.0] + + sdata_blobs.pl.render_shapes(element="blobs_circles", color="mixed_data", na_color="gray").pl.show() diff --git a/tests/pl/test_utils.py b/tests/pl/test_utils.py index 0eef85e3..a9296d2e 100644 --- a/tests/pl/test_utils.py +++ b/tests/pl/test_utils.py @@ -1,5 +1,6 @@ import matplotlib import matplotlib.pyplot as plt +import numpy as np import pandas as pd import pytest import scanpy as sc @@ -89,6 +90,37 @@ def test_is_color_like(color_result: tuple[ColorLike, bool]): assert spatialdata_plot.pl.utils._is_color_like(color) == result +def test_extract_scalar_value(): + """Test the new _extract_scalar_value function for robust numeric conversion.""" + + from spatialdata_plot.pl.utils import _extract_scalar_value + + # Test basic functionality + assert _extract_scalar_value(3.14) == 3.14 + assert _extract_scalar_value(42) == 42.0 + + # Test with collections + assert 
_extract_scalar_value(pd.Series([1.0, 2.0, 3.0])) == 1.0 + assert _extract_scalar_value([1.0, 2.0, 3.0]) == 1.0 + + # Test edge cases + assert _extract_scalar_value(np.nan) == 0.0 + assert _extract_scalar_value("invalid") == 0.0 + assert _extract_scalar_value([], default=1.0) == 1.0 + + +def test_plot_can_handle_rgba_color_specifications(sdata_blobs: SpatialData): + """Test handling of RGBA color specifications.""" + # Test with RGBA tuple + sdata_blobs.pl.render_shapes(element="blobs_circles", color=(1.0, 0.0, 0.0, 0.8)).pl.show() + + # Test with RGB tuple (no alpha) + sdata_blobs.pl.render_shapes(element="blobs_circles", color=(0.0, 1.0, 0.0)).pl.show() + + # Test with string color + sdata_blobs.pl.render_shapes(element="blobs_circles", color="blue").pl.show() + + @pytest.mark.parametrize( "input_output", [