1818from matplotlib .colors import ListedColormap , Normalize
1919from scanpy ._settings import settings as sc_settings
2020from spatialdata import get_extent
21- from spatialdata .models import PointsModel , get_table_keys
22- from spatialdata .transformations import (
23- set_transformation ,
24- )
21+ from spatialdata .models import PointsModel , ShapesModel , get_table_keys
22+ from spatialdata .transformations import get_transformation , set_transformation
23+ from spatialdata .transformations .transformations import Identity
2524from xarray import DataTree
2625
2726from spatialdata_plot ._logging import logger
4443 _get_colors_for_categorical_obs ,
4544 _get_extent_and_range_for_datashader_canvas ,
4645 _get_linear_colormap ,
46+ _get_transformation_matrix_for_datashader ,
4747 _is_coercable_to_float ,
4848 _map_color_seg ,
4949 _maybe_set_colors ,
@@ -148,7 +148,7 @@ def _render_shapes(
148148 colorbar = False if col_for_color is None else legend_params .colorbar
149149
150150 # Apply the transformation to the PatchCollection's paths
151- trans , _ = _prepare_transformation (sdata_filt .shapes [element ], coordinate_system )
151+ trans , trans_data = _prepare_transformation (sdata_filt .shapes [element ], coordinate_system )
152152
153153 shapes = gpd .GeoDataFrame (shapes , geometry = "geometry" )
154154
@@ -168,14 +168,6 @@ def _render_shapes(
168168 )
169169
170170 if method == "datashader" :
171- trans += ax .transData
172-
173- plot_width , plot_height , x_ext , y_ext , factor = _get_extent_and_range_for_datashader_canvas (
174- sdata_filt .shapes [element ], coordinate_system , ax , fig_params
175- )
176-
177- cvs = ds .Canvas (plot_width = plot_width , plot_height = plot_height , x_range = x_ext , y_range = y_ext )
178-
179171 _geometry = shapes ["geometry" ]
180172 is_point = _geometry .type == "Point"
181173
@@ -184,36 +176,48 @@ def _render_shapes(
184176 scale = shapes [is_point ]["radius" ] * render_params .scale
185177 sdata_filt .shapes [element ].loc [is_point , "geometry" ] = _geometry [is_point ].buffer (scale .to_numpy ())
186178
179+ # apply transformations to the individual points
180+ element_trans = get_transformation (sdata_filt .shapes [element ])
181+ tm = _get_transformation_matrix_for_datashader (element_trans )
182+ transformed_element = sdata_filt .shapes [element ].transform (
183+ lambda x : (np .hstack ([x , np .ones ((x .shape [0 ], 1 ))]) @ tm )[:, :2 ]
184+ )
185+ transformed_element = ShapesModel .parse (
186+ gpd .GeoDataFrame (data = sdata_filt .shapes [element ].drop ("geometry" , axis = 1 ), geometry = transformed_element )
187+ )
188+
189+ plot_width , plot_height , x_ext , y_ext , factor = _get_extent_and_range_for_datashader_canvas (
190+ transformed_element , coordinate_system , ax , fig_params
191+ )
192+
193+ cvs = ds .Canvas (plot_width = plot_width , plot_height = plot_height , x_range = x_ext , y_range = y_ext )
194+
187195 # in case we are coloring by a column in table
188- if col_for_color is not None and col_for_color not in sdata_filt .shapes [element ].columns :
189- sdata_filt .shapes [element ][col_for_color ] = (
190- color_vector if color_source_vector is None else color_source_vector
191- )
196+ if col_for_color is not None and col_for_color not in transformed_element .columns :
197+ transformed_element [col_for_color ] = color_vector if color_source_vector is None else color_source_vector
192198 # Render shapes with datashader
193199 color_by_categorical = col_for_color is not None and color_source_vector is not None
194200 aggregate_with_reduction = None
195201 if col_for_color is not None and (render_params .groups is None or len (render_params .groups ) > 1 ):
196202 if color_by_categorical :
197- agg = cvs .polygons (
198- sdata_filt .shapes [element ], geometry = "geometry" , agg = ds .by (col_for_color , ds .count ())
199- )
203+ agg = cvs .polygons (transformed_element , geometry = "geometry" , agg = ds .by (col_for_color , ds .count ()))
200204 else :
201205 reduction_name = render_params .ds_reduction if render_params .ds_reduction is not None else "mean"
202206 logger .info (
203207 f'Using the datashader reduction "{ reduction_name } ". "max" will give an output very close '
204208 "to the matplotlib result."
205209 )
206210 agg = _datashader_aggregate_with_function (
207- render_params .ds_reduction , cvs , sdata_filt . shapes [ element ] , col_for_color , "shapes"
211+ render_params .ds_reduction , cvs , transformed_element , col_for_color , "shapes"
208212 )
209213 # save min and max values for drawing the colorbar
210214 aggregate_with_reduction = (agg .min (), agg .max ())
211215 else :
212- agg = cvs .polygons (sdata_filt . shapes [ element ] , geometry = "geometry" , agg = ds .count ())
216+ agg = cvs .polygons (transformed_element , geometry = "geometry" , agg = ds .count ())
213217 # render outlines if needed
214218 if (render_outlines := render_params .outline_alpha ) > 0 :
215219 agg_outlines = cvs .line (
216- sdata_filt . shapes [ element ] ,
220+ transformed_element ,
217221 geometry = "geometry" ,
218222 line_width = render_params .outline_params .linewidth ,
219223 )
@@ -287,13 +291,23 @@ def _render_shapes(
287291
288292 rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
289293 _cax = _ax_show_and_transform (
290- rgba_image , trans_data , ax , zorder = render_params .zorder , alpha = render_params .fill_alpha
294+ rgba_image ,
295+ trans_data ,
296+ ax ,
297+ zorder = render_params .zorder ,
298+ alpha = render_params .fill_alpha ,
299+ extent = x_ext + y_ext ,
291300 )
292301 # render outline image if needed
293302 if render_outlines :
294303 rgba_image , trans_data = _create_image_from_datashader_result (ds_outlines , factor , ax )
295304 _ax_show_and_transform (
296- rgba_image , trans_data , ax , zorder = render_params .zorder , alpha = render_params .outline_alpha
305+ rgba_image ,
306+ trans_data ,
307+ ax ,
308+ zorder = render_params .zorder ,
309+ alpha = render_params .outline_alpha ,
310+ extent = x_ext + y_ext ,
297311 )
298312
299313 cax = None
@@ -330,7 +344,7 @@ def _render_shapes(
330344
331345 if not values_are_categorical :
332346 # If the user passed a Normalize object with vmin/vmax we'll use those,
333- # # if not we'll use the min/max of the color_vector
347+ # if not we'll use the min/max of the color_vector
334348 _cax .set_clim (
335349 vmin = render_params .cmap_params .norm .vmin or min (color_vector ),
336350 vmax = render_params .cmap_params .norm .vmax or max (color_vector ),
@@ -468,7 +482,7 @@ def _render_points(
468482 if color_source_vector is None and render_params .transfunc is not None :
469483 color_vector = render_params .transfunc (color_vector )
470484
471- _ , trans_data = _prepare_transformation (sdata .points [element ], coordinate_system , ax )
485+ trans , trans_data = _prepare_transformation (sdata .points [element ], coordinate_system , ax )
472486
473487 norm = copy (render_params .cmap_params .norm )
474488
@@ -491,8 +505,15 @@ def _render_points(
491505 # use dpi/100 as a factor for cases where dpi!=100
492506 px = int (np .round (np .sqrt (render_params .size ) * (fig_params .fig .dpi / 100 )))
493507
508+ # apply transformations
509+ transformed_element = PointsModel .parse (
510+ trans .transform (sdata_filt .points [element ][["x" , "y" ]]),
511+ annotation = sdata_filt .points [element ][sdata_filt .points [element ].columns .drop (["x" , "y" ])],
512+ transformations = {coordinate_system : Identity ()},
513+ )
514+
494515 plot_width , plot_height , x_ext , y_ext , factor = _get_extent_and_range_for_datashader_canvas (
495- sdata_filt . points [ element ] , coordinate_system , ax , fig_params
516+ transformed_element , coordinate_system , ax , fig_params
496517 )
497518
498519 # use datashader for the visualization of points
@@ -502,20 +523,20 @@ def _render_points(
502523 aggregate_with_reduction = None
503524 if col_for_color is not None and (render_params .groups is None or len (render_params .groups ) > 1 ):
504525 if color_by_categorical :
505- agg = cvs .points (sdata_filt . points [ element ] , "x" , "y" , agg = ds .by (col_for_color , ds .count ()))
526+ agg = cvs .points (transformed_element , "x" , "y" , agg = ds .by (col_for_color , ds .count ()))
506527 else :
507528 reduction_name = render_params .ds_reduction if render_params .ds_reduction is not None else "sum"
508529 logger .info (
509530 f'Using the datashader reduction "{ reduction_name } ". "max" will give an output very close '
510531 "to the matplotlib result."
511532 )
512533 agg = _datashader_aggregate_with_function (
513- render_params .ds_reduction , cvs , sdata_filt . points [ element ] , col_for_color , "points"
534+ render_params .ds_reduction , cvs , transformed_element , col_for_color , "points"
514535 )
515536 # save min and max values for drawing the colorbar
516537 aggregate_with_reduction = (agg .min (), agg .max ())
517538 else :
518- agg = cvs .points (sdata_filt . points [ element ] , "x" , "y" , agg = ds .count ())
539+ agg = cvs .points (transformed_element , "x" , "y" , agg = ds .count ())
519540
520541 if norm .vmin is not None or norm .vmax is not None :
521542 norm .vmin = np .min (agg ) if norm .vmin is None else norm .vmin
@@ -573,7 +594,14 @@ def _render_points(
573594 )
574595
575596 rgba_image , trans_data = _create_image_from_datashader_result (ds_result , factor , ax )
576- _ax_show_and_transform (rgba_image , trans_data , ax , zorder = render_params .zorder , alpha = render_params .alpha )
597+ _ax_show_and_transform (
598+ rgba_image ,
599+ trans_data ,
600+ ax ,
601+ zorder = render_params .zorder ,
602+ alpha = render_params .alpha ,
603+ extent = x_ext + y_ext ,
604+ )
577605
578606 cax = None
579607 if aggregate_with_reduction is not None :
0 commit comments