Skip to content

Commit 1bd9237

Browse files
author
Sonja Stockhaus
committed
fix introduced mpl shapes cbar bug, ds continuous case: render nan points/shapes separately before colored
1 parent edbcb88 commit 1bd9237

File tree

1 file changed

+67
-5
lines changed

1 file changed

+67
-5
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ def _render_shapes(
204204
transformed_element[col_for_color] = color_vector if color_source_vector is None else color_source_vector
205205
# Render shapes with datashader
206206
color_by_categorical = col_for_color is not None and color_source_vector is not None
207+
207208
aggregate_with_reduction = None
209+
continuous_nan_shapes = None
208210
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
209211
if color_by_categorical:
210212
# add nan as a category so that shapes with nan value are colored in the nan color
@@ -221,6 +223,13 @@ def _render_shapes(
221223
)
222224
# save min and max values for drawing the colorbar
223225
aggregate_with_reduction = (agg.min(), agg.max())
226+
227+
# nan shapes need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods)
228+
transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()]
229+
if len(transformed_element_nan_color) > 0:
230+
continuous_nan_shapes = _datashader_aggregate_with_function(
231+
"any", cvs, transformed_element_nan_color, None, "shapes"
232+
)
224233
else:
225234
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.count())
226235
# render outlines if needed
@@ -281,6 +290,18 @@ def _render_shapes(
281290
min_alpha=np.min([254, render_params.fill_alpha * 255]),
282291
)
283292

293+
if continuous_nan_shapes is not None:
294+
# for coloring by continuous variable: render nan shapes separately
295+
nan_color = render_params.cmap_params.na_color
296+
if isinstance(nan_color, str) and nan_color.startswith("#") and len(nan_color) == 9:
297+
nan_color = nan_color[:7]
298+
continuous_nan_shapes = ds.tf.shade(
299+
continuous_nan_shapes,
300+
cmap=nan_color,
301+
how="linear",
302+
min_alpha=np.min([254, render_params.fill_alpha * 255]),
303+
)
304+
284305
# shade outlines if needed
285306
outline_color = render_params.outline_params.outline_color
286307
if isinstance(outline_color, str) and outline_color.startswith("#") and len(outline_color) == 9:
@@ -298,6 +319,17 @@ def _render_shapes(
298319
how="linear",
299320
)
300321

322+
if continuous_nan_shapes is not None:
323+
# for coloring by continuous variable: render nan points separately
324+
rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_shapes, factor, ax)
325+
_ax_show_and_transform(
326+
rgba_image_nan,
327+
trans_data_nan,
328+
ax,
329+
zorder=render_params.zorder,
330+
alpha=render_params.fill_alpha,
331+
extent=x_ext + y_ext,
332+
)
301333
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
302334
_cax = _ax_show_and_transform(
303335
rgba_image,
@@ -335,7 +367,7 @@ def _render_shapes(
335367
_cax = _get_collection_shape(
336368
shapes=shapes,
337369
s=render_params.scale,
338-
c=color_vector,
370+
c=color_vector.copy(), # copy bc c is modified in _get_collection_shape
339371
render_params=render_params,
340372
rasterized=sc_settings._vector_friendly,
341373
cmap=render_params.cmap_params.cmap,
@@ -355,8 +387,8 @@ def _render_shapes(
355387
# If the user passed a Normalize object with vmin/vmax we'll use those,
356388
# if not we'll use the min/max of the color_vector
357389
_cax.set_clim(
358-
vmin=render_params.cmap_params.norm.vmin or min(color_vector),
359-
vmax=render_params.cmap_params.norm.vmax or max(color_vector),
390+
vmin=render_params.cmap_params.norm.vmin or np.nanmin(color_vector),
391+
vmax=render_params.cmap_params.norm.vmax or np.nanmax(color_vector),
360392
)
361393

362394
if len(set(color_vector)) != 1 or list(set(color_vector))[0] != to_hex(render_params.cmap_params.na_color):
@@ -578,7 +610,9 @@ def _render_points(
578610
)
579611
if color_by_categorical and transformed_element[col_for_color].values.dtype == object:
580612
transformed_element[col_for_color] = transformed_element[col_for_color].astype("category")
613+
581614
aggregate_with_reduction = None
615+
continuous_nan_points = None
582616
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
583617
if color_by_categorical:
584618
# add nan as category so that nan points are shown in the nan color
@@ -593,10 +627,15 @@ def _render_points(
593627
)
594628
agg = _datashader_aggregate_with_function(
595629
render_params.ds_reduction, cvs, transformed_element, col_for_color, "points"
596-
) # TODO: if color column contains NaN values, compute a second aggregation (only with NaN),
597-
# maybe using ==NaN and then "any" as reduction => render grey and then layer colored points on top
630+
)
598631
# save min and max values for drawing the colorbar
599632
aggregate_with_reduction = (agg.min(), agg.max())
633+
# nan points need to be rendered separately (else: invisible, bc nan is skipped by aggregation methods)
634+
transformed_element_nan_color = transformed_element[transformed_element[col_for_color].isnull()]
635+
if len(transformed_element_nan_color) > 0:
636+
continuous_nan_points = _datashader_aggregate_with_function(
637+
"any", cvs, transformed_element_nan_color, None, "points"
638+
)
600639
else:
601640
agg = cvs.points(transformed_element, "x", "y", agg=ds.count())
602641

@@ -655,6 +694,29 @@ def _render_points(
655694
how="linear",
656695
)
657696

697+
if continuous_nan_points is not None:
698+
# for coloring by continuous variable: render nan points separately
699+
nan_color = render_params.cmap_params.na_color
700+
if isinstance(nan_color, str) and nan_color.startswith("#") and len(nan_color) == 9:
701+
nan_color = nan_color[:7]
702+
continuous_nan_points = ds.tf.spread(continuous_nan_points, px=px, how=spread_how)
703+
continuous_nan_points = ds.tf.shade(
704+
continuous_nan_points,
705+
cmap=nan_color,
706+
how="linear",
707+
)
708+
709+
if continuous_nan_points is not None:
710+
# for coloring by continuous variable: render nan points separately
711+
rgba_image_nan, trans_data_nan = _create_image_from_datashader_result(continuous_nan_points, factor, ax)
712+
_ax_show_and_transform(
713+
rgba_image_nan,
714+
trans_data_nan,
715+
ax,
716+
zorder=render_params.zorder,
717+
alpha=render_params.alpha,
718+
extent=x_ext + y_ext,
719+
)
658720
rgba_image, trans_data = _create_image_from_datashader_result(ds_result, factor, ax)
659721
_ax_show_and_transform(
660722
rgba_image,

0 commit comments

Comments
 (0)