@@ -567,9 +567,6 @@ def _get_subplots(num_images: int, ncols: int = 4, width: int = 4, height: int =
567567 Union[plt.Figure, plt.Axes]
568568 Matplotlib figure and axes object.
569569 """
570- # if num_images <= 1:
571- # raise ValueError("Number of images must be greater than 1.")
572-
573570 if num_images < ncols :
574571 nrows = 1
575572 ncols = num_images
@@ -733,8 +730,6 @@ def _set_color_source_vec(
733730 color = np .full (len (element ), na_color )
734731 return color , color , False
735732
736- # model = get_model(sdata[element_name])
737-
738733 # Figure out where to get the color from
739734 origins = _locate_value (value_key = value_to_plot , sdata = sdata , element_name = element_name , table_name = table_name )
740735
@@ -778,16 +773,13 @@ def _set_color_source_vec(
778773 palette = palette ,
779774 na_color = na_color ,
780775 )
776+
781777 color_source_vector = color_source_vector .set_categories (color_mapping .keys ())
782778 if color_mapping is None :
783779 raise ValueError ("Unable to create color palette." )
784780
785781 # do not rename categories, as colors need not be unique
786782 color_vector = color_source_vector .map (color_mapping )
787- if color_vector .isna ().any ():
788- if (na_cat_color := to_hex (na_color )) not in color_vector .categories :
789- color_vector = color_vector .add_categories ([na_cat_color ])
790- color_vector = color_vector .fillna (to_hex (na_color ))
791783
792784 return color_source_vector , color_vector , True
793785
@@ -808,44 +800,43 @@ def _map_color_seg(
808800 seg_boundaries : bool = False ,
809801) -> ArrayLike :
810802 cell_id = np .array (cell_id )
811- if color_vector is not None and isinstance (color_vector .dtype , pd .CategoricalDtype ):
812- # users wants to plot a categorical column
803+
804+ if pd .api .types .is_categorical_dtype (color_vector .dtype ):
805+ # Case A: users wants to plot a categorical column
813806 if np .any (color_source_vector .isna ()):
814807 cell_id [color_source_vector .isna ()] = 0
815- val_im : ArrayLike = map_array (seg , cell_id , color_vector .codes + 1 )
808+ val_im : ArrayLike = map_array (seg . copy () , cell_id , color_vector .codes + 1 )
816809 cols = colors .to_rgba_array (color_vector .categories )
817-
818810 elif pd .api .types .is_numeric_dtype (color_vector .dtype ):
819- # user wants to plot a continous column
811+ # Case B: user wants to plot a continous column
820812 if isinstance (color_vector , pd .Series ):
821813 color_vector = color_vector .to_numpy ()
822- val_im = map_array (seg , cell_id , color_vector )
823814 cols = cmap_params .cmap (cmap_params .norm (color_vector ))
824-
815+ val_im = map_array ( seg . copy (), cell_id , cell_id )
825816 else :
826- val_im = map_array (seg .copy (), cell_id , cell_id ) # replace with same seg id to remove missing segs
827-
828- if val_im .shape [0 ] == 1 :
829- val_im = np .squeeze (val_im , axis = 0 )
830- if "#" in str (color_vector [0 ]):
831- # we have hex colors
832- assert all (_is_color_like (c ) for c in color_vector ), "Not all values are color-like."
833- cols = colors .to_rgba_array (color_vector )
817+ # Case C: User didn't specify any colors
818+ if color_source_vector is not None and (
819+ set (color_vector ) == set (color_source_vector )
820+ and len (set (color_vector )) == 1
821+ and set (color_vector ) == {na_color }
822+ and not na_color_modified_by_user
823+ ):
824+ val_im = map_array (seg .copy (), cell_id , cell_id )
825+ RNG = default_rng (42 )
826+ cols = RNG .random ((len (color_vector ), 3 ))
834827 else :
835- cols = cmap_params .cmap (cmap_params .norm (color_vector ))
828+ # Case D: User didn't specify a column to color by, but modified the na_color
829+ val_im = map_array (seg .copy (), cell_id , cell_id )
830+ if "#" in str (color_vector [0 ]):
831+ # we have hex colors
832+ assert all (_is_color_like (c ) for c in color_vector ), "Not all values are color-like."
833+ cols = colors .to_rgba_array (color_vector )
834+ else :
835+ cols = cmap_params .cmap (cmap_params .norm (color_vector ))
836836
837837 if seg_erosionpx is not None :
838838 val_im [val_im == erosion (val_im , square (seg_erosionpx ))] = 0
839839
840- if color_source_vector is not None and (
841- set (color_vector ) == set (color_source_vector )
842- and len (set (color_vector )) == 1
843- and set (color_vector ) == {na_color }
844- and not na_color_modified_by_user
845- ):
846- RNG = default_rng (42 )
847- cols = RNG .random ((len (cols ), 3 ))
848-
849840 seg_im : ArrayLike = label2rgb (
850841 label = val_im ,
851842 colors = cols ,
@@ -948,7 +939,7 @@ def _get_categorical_color_mapping(
948939 else :
949940 base_mapping = _generate_base_categorial_color_mapping (adata , cluster_key , color_source_vector , na_color )
950941
951- return _modify_categorical_color_mapping (base_mapping , groups , palette )
942+ return _modify_categorical_color_mapping (mapping = base_mapping , groups = groups , palette = palette )
952943
953944
954945def _maybe_set_colors (
@@ -1587,19 +1578,14 @@ def _type_check_params(param_dict: dict[str, Any], element_type: str) -> dict[st
15871578
15881579 palette = param_dict ["palette" ]
15891580
1590- if (groups := param_dict .get ("groups" )) is not None and palette is None :
1591- warnings .warn (
1592- "Groups is specified but palette is not. Setting palette to default 'lightgray'" , UserWarning , stacklevel = 2
1593- )
1594- param_dict ["palette" ] = ["lightgray" for _ in range (len (groups ))]
1595-
15961581 if isinstance ((palette := param_dict ["palette" ]), list ):
15971582 if not all (isinstance (p , str ) for p in palette ):
15981583 raise ValueError ("If specified, parameter 'palette' must contain only strings." )
15991584 elif isinstance (palette , (str , type (None ))) and "palette" in param_dict :
16001585 param_dict ["palette" ] = [palette ] if palette is not None else None
16011586
16021587 if element_type in ["shapes" , "points" , "labels" ] and (palette := param_dict .get ("palette" )) is not None :
1588+ groups = param_dict .get ("groups" )
16031589 if groups is None :
16041590 raise ValueError ("When specifying 'palette', 'groups' must also be specified." )
16051591 if len (groups ) != len (palette ):
0 commit comments