Skip to content

Commit edbcb88

Browse files
author
Sonja Stockhaus
committed
fix ds coloring by values in table and by categorical with nan
1 parent ac4677f commit edbcb88

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ def _render_shapes(
207207
aggregate_with_reduction = None
208208
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
209209
if color_by_categorical:
210+
# add nan as a category so that shapes with nan value are colored in the nan color
211+
transformed_element[col_for_color] = transformed_element[col_for_color].cat.add_categories("nan")
210212
agg = cvs.polygons(transformed_element, geometry="geometry", agg=ds.by(col_for_color, ds.count()))
211213
else:
212214
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "mean"
@@ -562,6 +564,14 @@ def _render_points(
562564
# use datashader for the visualization of points
563565
cvs = ds.Canvas(plot_width=plot_width, plot_height=plot_height, x_range=x_ext, y_range=y_ext)
564566

567+
# in case we are coloring by a column in table
568+
if col_for_color is not None and col_for_color not in transformed_element.columns:
569+
if color_source_vector is not None:
570+
transformed_element = transformed_element.assign(col_for_color=pd.Series(color_source_vector))
571+
else:
572+
transformed_element = transformed_element.assign(col_for_color=pd.Series(color_vector))
573+
transformed_element = transformed_element.rename(columns={"col_for_color": col_for_color})
574+
565575
color_by_categorical = col_for_color is not None and transformed_element[col_for_color].values.dtype in (
566576
object,
567577
"categorical",
@@ -571,6 +581,9 @@ def _render_points(
571581
aggregate_with_reduction = None
572582
if col_for_color is not None and (render_params.groups is None or len(render_params.groups) > 1):
573583
if color_by_categorical:
584+
# add nan as category so that nan points are shown in the nan color
585+
transformed_element[col_for_color] = transformed_element[col_for_color].cat.as_known()
586+
transformed_element[col_for_color] = transformed_element[col_for_color].cat.add_categories("nan")
574587
agg = cvs.points(transformed_element, "x", "y", agg=ds.by(col_for_color, ds.count()))
575588
else:
576589
reduction_name = render_params.ds_reduction if render_params.ds_reduction is not None else "sum"
@@ -580,7 +593,8 @@ def _render_points(
580593
)
581594
agg = _datashader_aggregate_with_function(
582595
render_params.ds_reduction, cvs, transformed_element, col_for_color, "points"
583-
)
596+
) # TODO: if color column contains NaN values, compute a second aggregation (only with NaN),
597+
# maybe using ==NaN and then "any" as reduction => render grey and then layer colored points on top
584598
# save min and max values for drawing the colorbar
585599
aggregate_with_reduction = (agg.min(), agg.max())
586600
else:

0 commit comments

Comments
 (0)