1010from matplotlib .colors import Normalize
1111from spatialdata import SpatialData , deepcopy
1212from spatialdata .models import PointsModel , TableModel
13- from spatialdata .transformations import Affine , Identity , MapAxis , Scale , Sequence , Translation
13+ from spatialdata .transformations import (
14+ Affine ,
15+ Identity ,
16+ MapAxis ,
17+ Scale ,
18+ Sequence ,
19+ Translation ,
20+ )
1421from spatialdata .transformations ._utils import _set_transformations
1522
1623import spatialdata_plot # noqa: F401
@@ -58,7 +65,9 @@ def test_plot_coloring_with_palette(self, sdata_blobs: SpatialData):
5865 sdata_blobs ["table" ].obs ["region" ] = ["blobs_points" ] * sdata_blobs ["table" ].n_obs
5966 sdata_blobs ["table" ].uns ["spatialdata_attrs" ]["region" ] = "blobs_points"
6067 sdata_blobs .pl .render_points (
61- color = "genes" , groups = ["gene_a" , "gene_b" ], palette = ["lightgreen" , "darkblue" ]
68+ color = "genes" ,
69+ groups = ["gene_a" , "gene_b" ],
70+ palette = ["lightgreen" , "darkblue" ],
6271 ).pl .show ()
6372
6473 def test_plot_coloring_with_cmap (self , sdata_blobs : SpatialData ):
@@ -81,35 +90,51 @@ def test_plot_color_recognises_actual_color_as_color(self, sdata_blobs: SpatialD
8190 def test_plot_points_coercable_categorical_color (self , sdata_blobs : SpatialData ):
8291 n_obs = len (sdata_blobs ["blobs_points" ])
8392 adata = AnnData (
84- RNG .normal (size = (n_obs , 10 )), obs = pd .DataFrame (RNG .normal (size = (n_obs , 3 )), columns = ["a" , "b" , "c" ])
93+ RNG .normal (size = (n_obs , 10 )),
94+ obs = pd .DataFrame (RNG .normal (size = (n_obs , 3 )), columns = ["a" , "b" , "c" ]),
8595 )
8696 adata .obs ["instance_id" ] = np .arange (adata .n_obs )
8797 adata .obs ["category" ] = RNG .choice (["a" , "b" , "c" ], size = adata .n_obs )
8898 adata .obs ["instance_id" ] = list (range (adata .n_obs ))
8999 adata .obs ["region" ] = "blobs_points"
90- table = TableModel .parse (adata = adata , region_key = "region" , instance_key = "instance_id" , region = "blobs_points" )
100+ table = TableModel .parse (
101+ adata = adata ,
102+ region_key = "region" ,
103+ instance_key = "instance_id" ,
104+ region = "blobs_points" ,
105+ )
91106 sdata_blobs ["other_table" ] = table
92107
93108 sdata_blobs .pl .render_points ("blobs_points" , color = "category" ).pl .show ()
94109
95110 def test_plot_points_categorical_color (self , sdata_blobs : SpatialData ):
96111 n_obs = len (sdata_blobs ["blobs_points" ])
97112 adata = AnnData (
98- RNG .normal (size = (n_obs , 10 )), obs = pd .DataFrame (RNG .normal (size = (n_obs , 3 )), columns = ["a" , "b" , "c" ])
113+ RNG .normal (size = (n_obs , 10 )),
114+ obs = pd .DataFrame (RNG .normal (size = (n_obs , 3 )), columns = ["a" , "b" , "c" ]),
99115 )
100116 adata .obs ["instance_id" ] = np .arange (adata .n_obs )
101117 adata .obs ["category" ] = RNG .choice (["a" , "b" , "c" ], size = adata .n_obs )
102118 adata .obs ["instance_id" ] = list (range (adata .n_obs ))
103119 adata .obs ["region" ] = "blobs_points"
104- table = TableModel .parse (adata = adata , region_key = "region" , instance_key = "instance_id" , region = "blobs_points" )
120+ table = TableModel .parse (
121+ adata = adata ,
122+ region_key = "region" ,
123+ instance_key = "instance_id" ,
124+ region = "blobs_points" ,
125+ )
105126 sdata_blobs ["other_table" ] = table
106127
107128 sdata_blobs ["other_table" ].obs ["category" ] = sdata_blobs ["other_table" ].obs ["category" ].astype ("category" )
108129 sdata_blobs .pl .render_points ("blobs_points" , color = "category" ).pl .show ()
109130
110131 def test_plot_datashader_continuous_color (self , sdata_blobs : SpatialData ):
111132 sdata_blobs .pl .render_points (
112- element = "blobs_points" , size = 40 , color = "instance_id" , alpha = 0.6 , method = "datashader"
133+ element = "blobs_points" ,
134+ size = 40 ,
135+ color = "instance_id" ,
136+ alpha = 0.6 ,
137+ method = "datashader" ,
113138 ).pl .show ()
114139
115140 def test_plot_points_categorical_color_column_matplotlib (self , sdata_blobs : SpatialData ):
@@ -131,32 +156,56 @@ def test_plot_datashader_matplotlib_stack(self, sdata_blobs: SpatialData):
131156
132157 def test_plot_datashader_can_color_by_category (self , sdata_blobs : SpatialData ):
133158 sdata_blobs .pl .render_points (
134- color = "genes" , groups = "gene_b" , palette = "lightgreen" , size = 20 , method = "datashader"
159+ color = "genes" ,
160+ groups = "gene_b" ,
161+ palette = "lightgreen" ,
162+ size = 20 ,
163+ method = "datashader" ,
135164 ).pl .show ()
136165
137166 def test_plot_datashader_can_use_sum_as_reduction (self , sdata_blobs : SpatialData ):
138167 sdata_blobs .pl .render_points (
139- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "sum"
168+ element = "blobs_points" ,
169+ size = 40 ,
170+ color = "instance_id" ,
171+ method = "datashader" ,
172+ datashader_reduction = "sum" ,
140173 ).pl .show ()
141174
142175 def test_plot_datashader_can_use_mean_as_reduction (self , sdata_blobs : SpatialData ):
143176 sdata_blobs .pl .render_points (
144- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "mean"
177+ element = "blobs_points" ,
178+ size = 40 ,
179+ color = "instance_id" ,
180+ method = "datashader" ,
181+ datashader_reduction = "mean" ,
145182 ).pl .show ()
146183
147184 def test_plot_datashader_can_use_any_as_reduction (self , sdata_blobs : SpatialData ):
148185 sdata_blobs .pl .render_points (
149- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "any"
186+ element = "blobs_points" ,
187+ size = 40 ,
188+ color = "instance_id" ,
189+ method = "datashader" ,
190+ datashader_reduction = "any" ,
150191 ).pl .show ()
151192
152193 def test_plot_datashader_can_use_count_as_reduction (self , sdata_blobs : SpatialData ):
153194 sdata_blobs .pl .render_points (
154- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "count"
195+ element = "blobs_points" ,
196+ size = 40 ,
197+ color = "instance_id" ,
198+ method = "datashader" ,
199+ datashader_reduction = "count" ,
155200 ).pl .show ()
156201
157202 def test_plot_datashader_can_use_std_as_reduction (self , sdata_blobs : SpatialData ):
158203 sdata_blobs .pl .render_points (
159- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "std"
204+ element = "blobs_points" ,
205+ size = 40 ,
206+ color = "instance_id" ,
207+ method = "datashader" ,
208+ datashader_reduction = "std" ,
160209 ).pl .show ()
161210
162211 def test_plot_datashader_can_use_std_as_reduction_not_all_zero (self , sdata_blobs : SpatialData ):
@@ -168,34 +217,59 @@ def test_plot_datashader_can_use_std_as_reduction_not_all_zero(self, sdata_blobs
168217 temp .loc [195 , "instance_id" ] = 13
169218 blob ["blobs_points" ] = PointsModel .parse (dask .dataframe .from_pandas (temp , 1 ), coordinates = {"x" : "x" , "y" : "y" })
170219 blob .pl .render_points (
171- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "std"
220+ element = "blobs_points" ,
221+ size = 40 ,
222+ color = "instance_id" ,
223+ method = "datashader" ,
224+ datashader_reduction = "std" ,
172225 ).pl .show ()
173226
174227 def test_plot_datashader_can_use_var_as_reduction (self , sdata_blobs : SpatialData ):
175228 sdata_blobs .pl .render_points (
176- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "var"
229+ element = "blobs_points" ,
230+ size = 40 ,
231+ color = "instance_id" ,
232+ method = "datashader" ,
233+ datashader_reduction = "var" ,
177234 ).pl .show ()
178235
179236 def test_plot_datashader_can_use_max_as_reduction (self , sdata_blobs : SpatialData ):
180237 sdata_blobs .pl .render_points (
181- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "max"
238+ element = "blobs_points" ,
239+ size = 40 ,
240+ color = "instance_id" ,
241+ method = "datashader" ,
242+ datashader_reduction = "max" ,
182243 ).pl .show ()
183244
184245 def test_plot_datashader_can_use_min_as_reduction (self , sdata_blobs : SpatialData ):
185246 sdata_blobs .pl .render_points (
186- element = "blobs_points" , size = 40 , color = "instance_id" , method = "datashader" , datashader_reduction = "min"
247+ element = "blobs_points" ,
248+ size = 40 ,
249+ color = "instance_id" ,
250+ method = "datashader" ,
251+ datashader_reduction = "min" ,
187252 ).pl .show ()
188253
189254 def test_plot_mpl_and_datashader_point_sizes_agree_after_altered_dpi (self , sdata_blobs : SpatialData ):
190255 sdata_blobs .pl .render_points (element = "blobs_points" , size = 400 , color = "blue" ).pl .render_points (
191- element = "blobs_points" , size = 400 , color = "yellow" , method = "datashader" , alpha = 0.8
256+ element = "blobs_points" ,
257+ size = 400 ,
258+ color = "yellow" ,
259+ method = "datashader" ,
260+ alpha = 0.8 ,
192261 ).pl .show (dpi = 200 )
193262
194263 def test_plot_points_transformed_ds_agrees_with_mpl (self ):
195264 sdata = SpatialData (
196265 points = {
197266 "points1" : PointsModel .parse (
198- pd .DataFrame ({"y" : [0 , 0 , 10 , 10 , 4 , 6 , 4 , 6 ], "x" : [0 , 10 , 10 , 0 , 4 , 6 , 6 , 4 ]}),
267+ pd .DataFrame (
268+ {
269+ "y" : [0 , 0 , 10 , 10 , 4 , 6 , 4 , 6 ],
270+ "x" : [0 , 10 , 10 , 0 , 4 , 6 , 6 , 4 ],
271+ }
272+ ),
199273 transformations = {"global" : Scale ([2 , 2 ], ("y" , "x" ))},
200274 )
201275 },
@@ -228,12 +302,18 @@ def test_plot_datashader_can_transform_points(self, sdata_blobs: SpatialData):
228302
229303 def test_plot_can_use_norm_with_clip (self , sdata_blobs : SpatialData ):
230304 sdata_blobs .pl .render_points (
231- color = "instance_id" , size = 40 , norm = Normalize (3 , 7 , clip = True ), cmap = _viridis_with_under_over ()
305+ color = "instance_id" ,
306+ size = 40 ,
307+ norm = Normalize (3 , 7 , clip = True ),
308+ cmap = _viridis_with_under_over (),
232309 ).pl .show ()
233310
234311 def test_plot_can_use_norm_without_clip (self , sdata_blobs : SpatialData ):
235312 sdata_blobs .pl .render_points (
236- color = "instance_id" , size = 40 , norm = Normalize (3 , 7 , clip = False ), cmap = _viridis_with_under_over ()
313+ color = "instance_id" ,
314+ size = 40 ,
315+ norm = Normalize (3 , 7 , clip = False ),
316+ cmap = _viridis_with_under_over (),
237317 ).pl .show ()
238318
239319 def test_plot_datashader_can_use_norm_with_clip (self , sdata_blobs : SpatialData ):
@@ -290,7 +370,12 @@ def test_plot_can_annotate_points_with_table_obs(self, sdata_blobs: SpatialData)
290370 obs ["extra_feature" ] = [1 , 2 ] * 100
291371
292372 table = AnnData (X = feature_matrix , var = pd .DataFrame (index = var_names ), obs = obs )
293- table = TableModel .parse (table , region = "blobs_points" , region_key = "region" , instance_key = "instance_id" )
373+ table = TableModel .parse (
374+ table ,
375+ region = "blobs_points" ,
376+ region_key = "region" ,
377+ instance_key = "instance_id" ,
378+ )
294379 sdata_blobs ["points_table" ] = table
295380
296381 sdata_blobs .pl .render_points ("blobs_points" , color = "extra_feature" , size = 10 ).pl .show ()
@@ -308,7 +393,12 @@ def test_plot_can_annotate_points_with_table_X(self, sdata_blobs: SpatialData):
308393 obs ["region" ].astype ("category" )
309394
310395 table = AnnData (X = feature_matrix , var = pd .DataFrame (index = var_names ), obs = obs )
311- table = TableModel .parse (table , region = "blobs_points" , region_key = "region" , instance_key = "instance_id" )
396+ table = TableModel .parse (
397+ table ,
398+ region = "blobs_points" ,
399+ region_key = "region" ,
400+ instance_key = "instance_id" ,
401+ )
312402 sdata_blobs ["points_table" ] = table
313403
314404 sdata_blobs .pl .render_points ("blobs_points" , color = "feature0" , size = 10 ).pl .show ()
@@ -327,13 +417,20 @@ def test_plot_can_annotate_points_with_table_and_groups(self, sdata_blobs: Spati
327417 obs ["extra_feature_cat" ] = ["one" , "two" ] * 100
328418
329419 table = AnnData (X = feature_matrix , var = pd .DataFrame (index = var_names ), obs = obs )
330- table = TableModel .parse (table , region = "blobs_points" , region_key = "region" , instance_key = "instance_id" )
420+ table = TableModel .parse (
421+ table ,
422+ region = "blobs_points" ,
423+ region_key = "region" ,
424+ instance_key = "instance_id" ,
425+ )
331426 sdata_blobs ["points_table" ] = table
332427
333428 sdata_blobs .pl .render_points ("blobs_points" , color = "extra_feature_cat" , groups = "two" , size = 10 ).pl .show ()
334429
335430 def test_plot_can_annotate_points_with_table_layer (self , sdata_blobs : SpatialData ):
336431 nrows , ncols = 200 , 3
432+ # reset seed for reproducibility
433+ RNG .seed (42 )
337434 feature_matrix = RNG .random ((nrows , ncols ))
338435 var_names = [f"feature{ i } " for i in range (ncols )]
339436
@@ -345,7 +442,12 @@ def test_plot_can_annotate_points_with_table_layer(self, sdata_blobs: SpatialDat
345442 obs ["region" ].astype ("category" )
346443
347444 table = AnnData (X = feature_matrix , var = pd .DataFrame (index = var_names ), obs = obs )
348- table = TableModel .parse (table , region = "blobs_points" , region_key = "region" , instance_key = "instance_id" )
445+ table = TableModel .parse (
446+ table ,
447+ region = "blobs_points" ,
448+ region_key = "region" ,
449+ instance_key = "instance_id" ,
450+ )
349451 sdata_blobs ["points_table" ] = table
350452 sdata_blobs ["points_table" ].layers ["normalized" ] = RNG .random ((nrows , ncols ))
351453
0 commit comments