|
17 | 17 | from matplotlib.cm import ScalarMappable |
18 | 18 | from matplotlib.colors import ListedColormap, Normalize |
19 | 19 | from scanpy._settings import settings as sc_settings |
20 | | -from spatialdata import get_element_annotators, get_extent, get_values, join_spatialelement_table |
| 20 | +from spatialdata import get_extent, get_values, join_spatialelement_table |
21 | 21 | from spatialdata.models import PointsModel, ShapesModel, get_table_keys |
22 | 22 | from spatialdata.transformations import get_transformation, set_transformation |
23 | 23 | from spatialdata.transformations.transformations import Identity |
@@ -73,36 +73,24 @@ def _render_shapes( |
73 | 73 | col_for_color = render_params.col_for_color |
74 | 74 | groups = render_params.groups |
75 | 75 | table_layer = render_params.table_layer |
76 | | - table_name = render_params.table_name |
77 | 76 |
|
78 | 77 | sdata_filt = sdata.filter_by_coordinate_system( |
79 | 78 | coordinate_system=coordinate_system, |
80 | 79 | filter_tables=bool(render_params.table_name), |
81 | 80 | ) |
82 | | - if table_name is None: |
| 81 | + |
| 82 | + if (table_name := render_params.table_name) is None: |
83 | 83 | table = None |
84 | 84 | shapes = sdata_filt[element] |
85 | 85 | else: |
86 | | - # Check if the table actually annotates the element |
87 | | - annotating_tables = get_element_annotators(sdata, element) |
88 | | - if table_name not in annotating_tables: |
89 | | - warnings.warn( |
90 | | - f"Table '{table_name}' does not annotate element '{element}'", |
91 | | - UserWarning, |
92 | | - stacklevel=2, |
93 | | - ) |
94 | | - # Fall back to no table |
95 | | - table = None |
96 | | - shapes = sdata_filt[element] |
97 | | - else: |
98 | | - element_dict, joined_table = join_spatialelement_table( |
99 | | - sdata, spatial_element_names=element, table_name=table_name, how="inner" |
100 | | - ) |
101 | | - sdata_filt[element] = shapes = element_dict[element] |
102 | | - joined_table.uns["spatialdata_attrs"]["region"] = ( |
103 | | - joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist() |
104 | | - ) |
105 | | - sdata_filt[table_name] = table = joined_table |
| 86 | + element_dict, joined_table = join_spatialelement_table( |
| 87 | + sdata, spatial_element_names=element, table_name=table_name, how="inner" |
| 88 | + ) |
| 89 | + sdata_filt[element] = shapes = element_dict[element] |
| 90 | + joined_table.uns["spatialdata_attrs"]["region"] = ( |
| 91 | + joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist() |
| 92 | + ) |
| 93 | + sdata_filt[table_name] = table = joined_table |
106 | 94 |
|
107 | 95 | if ( |
108 | 96 | col_for_color is not None |
@@ -502,39 +490,22 @@ def _render_points( |
502 | 490 | dtype=points[["x", "y"]].values.dtype, |
503 | 491 | ) |
504 | 492 | else: |
505 | | - # Check if the table actually annotates the element |
506 | | - annotating_tables = get_element_annotators(sdata, element) |
507 | | - if table_name not in annotating_tables: |
508 | | - warnings.warn( |
509 | | - f"Table '{table_name}' does not annotate element '{element}'", |
510 | | - UserWarning, |
511 | | - stacklevel=2, |
512 | | - ) |
513 | | - # Fall back to no table |
514 | | - adata = AnnData( |
515 | | - X=points[["x", "y"]].values, |
516 | | - obs=points[coords].reset_index(), |
517 | | - dtype=points[["x", "y"]].values.dtype, |
518 | | - ) |
519 | | - else: |
520 | | - adata_obs = sdata_filt[table_name].obs |
521 | | - # if the points are colored by values in X (or a different layer), add the values to obs |
522 | | - if col_for_color in sdata_filt[table_name].var_names: |
523 | | - if table_layer is None: |
524 | | - adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].X.flatten().copy() |
525 | | - else: |
526 | | - adata_obs[col_for_color] = ( |
527 | | - sdata_filt[table_name][:, col_for_color].layers[table_layer].flatten().copy() |
528 | | - ) |
529 | | - if groups is not None: |
530 | | - adata_obs = adata_obs[adata_obs[col_for_color].isin(groups)] |
531 | | - adata = AnnData( |
532 | | - X=points[["x", "y"]].values, |
533 | | - obs=adata_obs, |
534 | | - dtype=points[["x", "y"]].values.dtype, |
535 | | - uns=sdata_filt[table_name].uns, |
536 | | - ) |
537 | | - sdata_filt[table_name] = adata |
| 493 | + adata_obs = sdata_filt[table_name].obs |
| 494 | + # if the points are colored by values in X (or a different layer), add the values to obs |
| 495 | + if col_for_color in sdata_filt[table_name].var_names: |
| 496 | + if table_layer is None: |
| 497 | + adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].X.flatten().copy() |
| 498 | + else: |
| 499 | + adata_obs[col_for_color] = sdata_filt[table_name][:, col_for_color].layers[table_layer].flatten().copy() |
| 500 | + if groups is not None: |
| 501 | + adata_obs = adata_obs[adata_obs[col_for_color].isin(groups)] |
| 502 | + adata = AnnData( |
| 503 | + X=points[["x", "y"]].values, |
| 504 | + obs=adata_obs, |
| 505 | + dtype=points[["x", "y"]].values.dtype, |
| 506 | + uns=sdata_filt[table_name].uns, |
| 507 | + ) |
| 508 | + sdata_filt[table_name] = adata |
538 | 509 |
|
539 | 510 | # we can modify the sdata because of dealing with a copy |
540 | 511 |
|
@@ -1080,23 +1051,11 @@ def _render_labels( |
1080 | 1051 | instance_id = np.unique(label) |
1081 | 1052 | table = None |
1082 | 1053 | else: |
1083 | | - # Check if the table actually annotates the element |
1084 | | - annotating_tables = get_element_annotators(sdata, element) |
1085 | | - if table_name not in annotating_tables: |
1086 | | - warnings.warn( |
1087 | | - f"Table '{table_name}' does not annotate element '{element}'", |
1088 | | - UserWarning, |
1089 | | - stacklevel=2, |
1090 | | - ) |
1091 | | - # Fall back to no table |
1092 | | - instance_id = np.unique(label) |
1093 | | - table = None |
1094 | | - else: |
1095 | | - _, region_key, instance_key = get_table_keys(sdata[table_name]) |
1096 | | - table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])] |
| 1054 | + _, region_key, instance_key = get_table_keys(sdata[table_name]) |
| 1055 | + table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])] |
1097 | 1056 |
|
1098 | | - # get instance id based on subsetted table |
1099 | | - instance_id = np.unique(table.obs[instance_key].values) |
| 1057 | + # get instance id based on subsetted table |
| 1058 | + instance_id = np.unique(table.obs[instance_key].values) |
1100 | 1059 |
|
1101 | 1060 | _, trans_data = _prepare_transformation(label, coordinate_system, ax) |
1102 | 1061 |
|
|
0 commit comments