|
17 | 17 | from matplotlib.cm import ScalarMappable |
18 | 18 | from matplotlib.colors import ListedColormap, Normalize |
19 | 19 | from scanpy._settings import settings as sc_settings |
20 | | -from spatialdata import get_extent |
| 20 | +from spatialdata import get_extent, join_spatialelement_table |
21 | 21 | from spatialdata.models import PointsModel, ShapesModel, get_table_keys |
22 | 22 | from spatialdata.transformations import get_transformation, set_transformation |
23 | 23 | from spatialdata.transformations.transformations import Identity |
@@ -76,13 +76,18 @@ def _render_shapes( |
76 | 76 | filter_tables=bool(render_params.table_name), |
77 | 77 | ) |
78 | 78 |
|
79 | | - shapes = sdata[element] |
80 | | - |
81 | 79 | if (table_name := render_params.table_name) is None: |
82 | 80 | table = None |
| 81 | + shapes = sdata_filt[element] |
83 | 82 | else: |
84 | | - _, region_key, _ = get_table_keys(sdata[table_name]) |
85 | | - table = sdata[table_name][sdata[table_name].obs[region_key].isin([element])] |
| 83 | + element_dict, joined_table = join_spatialelement_table( |
| 84 | + sdata, spatial_element_names=element, table_name=table_name, how="inner" |
| 85 | + ) |
| 86 | + sdata_filt[element] = shapes = element_dict[element] |
| 87 | + joined_table.uns["spatialdata_attrs"]["region"] = ( |
| 88 | + joined_table.obs[joined_table.uns["spatialdata_attrs"]["region_key"]].unique().tolist() |
| 89 | + ) |
| 90 | + sdata_filt[table_name] = table = joined_table |
86 | 91 |
|
87 | 92 | if ( |
88 | 93 | col_for_color is not None |
|
0 commit comments