Skip to content

Commit fa2ec17

Browse files
committed
added warning
1 parent 323dcac commit fa2ec17

File tree

2 files changed

+38
-73
lines changed

2 files changed

+38
-73
lines changed

src/spatialdata_plot/pl/render.py

Lines changed: 31 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from matplotlib.cm import ScalarMappable
1818
from matplotlib.colors import ListedColormap, Normalize
1919
from scanpy._settings import settings as sc_settings
20-
from spatialdata import get_element_annotators, get_extent, get_values, join_spatialelement_table
20+
from spatialdata import get_extent, get_values, join_spatialelement_table
2121
from spatialdata.models import PointsModel, ShapesModel, get_table_keys
2222
from spatialdata.transformations import get_transformation, set_transformation
2323
from spatialdata.transformations.transformations import Identity
@@ -73,36 +73,24 @@ def _render_shapes(
7373
col_for_color = render_params.col_for_color
7474
groups = render_params.groups
7575
table_layer = render_params.table_layer
76-
table_name = render_params.table_name
7776

7877
sdata_filt = sdata.filter_by_coordinate_system(
7978
coordinate_system=coordinate_system,
8079
filter_tables=bool(render_params.table_name),
8180
)
82-
if table_name is None:
81+
82+
if (table_name := render_params.table_name) is None:
8383
table = None
8484
shapes = sdata_filt[element]
8585
else:
86-
# Check if the table actually annotates the element
87-
annotating_tables = get_element_annotators(sdata, element)
88-
if table_name not in annotating_tables:
89-
warnings.warn(
90-
f"Table '{table_name}' does not annotate element '{element}'",
91-
UserWarning,
92-
stacklevel=2,
93-
)
94-
# Fall back to no table
95-
table = None
96-
shapes = sdata_filt[element]
97-
else:
98-
element_dict, joined_table = join_spatialelement_table(
99-
sdata, spatial_element_names=element, table_name=table_name, how="inner"
100-
)
101-
sdata_filt[element] = shapes = element_dict[element]
102-
joined_table.uns["spatialdata_attrs"]["region"] = (
103-
joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist()
104-
)
105-
sdata_filt[table_name] = table = joined_table
86+
element_dict, joined_table = join_spatialelement_table(
87+
sdata, spatial_element_names=element, table_name=table_name, how="inner"
88+
)
89+
sdata_filt[element] = shapes = element_dict[element]
90+
joined_table.uns["spatialdata_attrs"]["region"] = (
91+
joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist()
92+
)
93+
sdata_filt[table_name] = table = joined_table
10694

10795
if (
10896
col_for_color is not None
@@ -502,39 +490,22 @@ def _render_points(
502490
dtype=points[["x", "y"]].values.dtype,
503491
)
504492
else:
505-
# Check if the table actually annotates the element
506-
annotating_tables = get_element_annotators(sdata, element)
507-
if table_name not in annotating_tables:
508-
warnings.warn(
509-
f"Table '{table_name}' does not annotate element '{element}'",
510-
UserWarning,
511-
stacklevel=2,
512-
)
513-
# Fall back to no table
514-
adata = AnnData(
515-
X=points[["x", "y"]].values,
516-
obs=points[coords].reset_index(),
517-
dtype=points[["x", "y"]].values.dtype,
518-
)
519-
else:
520-
adata_obs = sdata_filt[table_name].obs
521-
# if the points are colored by values in X (or a different layer), add the values to obs
522-
if col_for_color in sdata_filt[table_name].var_names:
523-
if table_layer is None:
524-
adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].X.flatten().copy()
525-
else:
526-
adata_obs[col_for_color] = (
527-
sdata_filt[table_name][:, col_for_color].layers[table_layer].flatten().copy()
528-
)
529-
if groups is not None:
530-
adata_obs = adata_obs[adata_obs[col_for_color].isin(groups)]
531-
adata = AnnData(
532-
X=points[["x", "y"]].values,
533-
obs=adata_obs,
534-
dtype=points[["x", "y"]].values.dtype,
535-
uns=sdata_filt[table_name].uns,
536-
)
537-
sdata_filt[table_name] = adata
493+
adata_obs = sdata_filt[table_name].obs
494+
# if the points are colored by values in X (or a different layer), add the values to obs
495+
if col_for_color in sdata_filt[table_name].var_names:
496+
if table_layer is None:
497+
adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].X.flatten().copy()
498+
else:
499+
adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].layers[table_layer].flatten().copy()
500+
if groups is not None:
501+
adata_obs = adata_obs[adata_obs[col_for_color].isin(groups)]
502+
adata = AnnData(
503+
X=points[["x", "y"]].values,
504+
obs=adata_obs,
505+
dtype=points[["x", "y"]].values.dtype,
506+
uns=sdata_filt[table_name].uns,
507+
)
508+
sdata_filt[table_name] = adata
538509

539510
# we can modify the sdata because of dealing with a copy
540511

@@ -1080,23 +1051,11 @@ def _render_labels(
10801051
instance_id = np.unique(label)
10811052
table = None
10821053
else:
1083-
# Check if the table actually annotates the element
1084-
annotating_tables = get_element_annotators(sdata, element)
1085-
if table_name not in annotating_tables:
1086-
warnings.warn(
1087-
f"Table '{table_name}' does not annotate element '{element}'",
1088-
UserWarning,
1089-
stacklevel=2,
1090-
)
1091-
# Fall back to no table
1092-
instance_id = np.unique(label)
1093-
table = None
1094-
else:
1095-
_, region_key, instance_key = get_table_keys(sdata[table_name])
1096-
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]
1054+
_, region_key, instance_key = get_table_keys(sdata[table_name])
1055+
table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])]
10971056

1098-
# get instance id based on subsetted table
1099-
instance_id = np.unique(table.obs[instance_key].values)
1057+
# get instance id based on subsetted table
1058+
instance_id = np.unique(table.obs[instance_key].values)
11001059

11011060
_, trans_data = _prepare_transformation(label, coordinate_system, ax)
11021061

src/spatialdata_plot/pl/utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1833,7 +1833,8 @@ def _validate_label_render_params(
18331833

18341834
element_params[el]["table_name"] = None
18351835
element_params[el]["color"] = None
1836-
if (color := param_dict["color"]) is not None:
1836+
color = param_dict["color"]
1837+
if color is not None:
18371838
color, table_name = _validate_col_for_column_table(sdata, el, color, param_dict["table_name"], labels=True)
18381839
element_params[el]["table_name"] = table_name
18391840
element_params[el]["color"] = color
@@ -1995,6 +1996,11 @@ def _validate_col_for_column_table(
19951996
if table_name not in tables or (
19961997
col_for_color not in sdata[table_name].obs.columns and col_for_color not in sdata[table_name].var_names
19971998
):
1999+
warnings.warn(
2000+
f"Table '{table_name}' does not annotate element '{element_name}'.",
2001+
UserWarning,
2002+
stacklevel=2,
2003+
)
19982004
table_name = None
19992005
col_for_color = None
20002006
else:

0 commit comments

Comments
 (0)