@@ -207,6 +207,8 @@ def _render_shapes(
207207 aggregate_with_reduction = None
208208 if col_for_color is not None and (render_params .groups is None or len (render_params .groups ) > 1 ):
209209 if color_by_categorical :
210+ # add nan as a category so that shapes with nan value are colored in the nan color
211+ transformed_element [col_for_color ] = transformed_element [col_for_color ].cat .add_categories ("nan" )
210212 agg = cvs .polygons (transformed_element , geometry = "geometry" , agg = ds .by (col_for_color , ds .count ()))
211213 else :
212214 reduction_name = render_params .ds_reduction if render_params .ds_reduction is not None else "mean"
@@ -562,6 +564,14 @@ def _render_points(
562564 # use datashader for the visualization of points
563565 cvs = ds .Canvas (plot_width = plot_width , plot_height = plot_height , x_range = x_ext , y_range = y_ext )
564566
567+ # in case we are coloring by a column in table
568+ if col_for_color is not None and col_for_color not in transformed_element .columns :
569+ if color_source_vector is not None :
570+ transformed_element = transformed_element .assign (col_for_color = pd .Series (color_source_vector ))
571+ else :
572+ transformed_element = transformed_element .assign (col_for_color = pd .Series (color_vector ))
573+ transformed_element = transformed_element .rename (columns = {"col_for_color" : col_for_color })
574+
565575 color_by_categorical = col_for_color is not None and transformed_element [col_for_color ].values .dtype in (
566576 object ,
567577 "categorical" ,
@@ -571,6 +581,9 @@ def _render_points(
571581 aggregate_with_reduction = None
572582 if col_for_color is not None and (render_params .groups is None or len (render_params .groups ) > 1 ):
573583 if color_by_categorical :
584+ # add nan as category so that nan points are shown in the nan color
585+ transformed_element [col_for_color ] = transformed_element [col_for_color ].cat .as_known ()
586+ transformed_element [col_for_color ] = transformed_element [col_for_color ].cat .add_categories ("nan" )
574587 agg = cvs .points (transformed_element , "x" , "y" , agg = ds .by (col_for_color , ds .count ()))
575588 else :
576589 reduction_name = render_params .ds_reduction if render_params .ds_reduction is not None else "sum"
@@ -580,7 +593,8 @@ def _render_points(
580593 )
581594 agg = _datashader_aggregate_with_function (
582595 render_params .ds_reduction , cvs , transformed_element , col_for_color , "points"
583- )
596+ ) # TODO: if color column contains NaN values, compute a second aggregation (only with NaN),
597+ # maybe using ==NaN and then "any" as reduction => render grey and then layer colored points on top
584598 # save min and max values for drawing the colorbar
585599 aggregate_with_reduction = (agg .min (), agg .max ())
586600 else :
0 commit comments