77import pandas as pd
88from .png import Writer , from_array
99import numpy as np
10+ import itertools
1011
1112try :
1213 import xarray
@@ -293,31 +294,41 @@ def imshow(
293294 args = locals ()
294295 apply_default_cascade (args )
295296 labels = labels .copy ()
296- nslices = 1
297+ nslices_facet = 1
297298 if facet_col is not None :
298299 if isinstance (facet_col , str ):
299300 facet_col = img .dims .index (facet_col )
300- nslices = img .shape [facet_col ]
301- ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices
302- nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
301+ nslices_facet = img .shape [facet_col ]
302+ facet_slices = range (nslices_facet )
303+ ncols = int (facet_col_wrap ) if facet_col_wrap is not None else nslices_facet
304+ nrows = (
305+ nslices_facet // ncols + 1
306+ if nslices_facet % ncols
307+ else nslices_facet // ncols
308+ )
303309 else :
304310 nrows = 1
305311 ncols = 1
306312 if animation_frame is not None :
307313 if isinstance (animation_frame , str ):
308314 animation_frame = img .dims .index (animation_frame )
309- nslices = img .shape [animation_frame ]
315+ nslices_animation = img .shape [animation_frame ]
316+ animation_slices = range (nslices_animation )
310317 slice_through = (facet_col is not None ) or (animation_frame is not None )
311- slice_label = None
312- slices = range (nslices )
318+ double_slice_through = (facet_col is not None ) and (animation_frame is not None )
319+ facet_label = None
320+ animation_label = None
313321 # ----- Define x and y, set labels if img is an xarray -------------------
314322 if xarray_imported and isinstance (img , xarray .DataArray ):
315323 dims = list (img .dims )
316- if slice_through :
317- slice_index = facet_col if facet_col is not None else animation_frame
318- slices = img .coords [img .dims [slice_index ]].values
319- _ = dims .pop (slice_index )
320- slice_label = img .dims [slice_index ]
324+ if facet_col is not None :
325+ facet_slices = img .coords [img .dims [facet_col ]].values
326+ _ = dims .pop (facet_col )
327+ facet_label = img .dims [facet_col ]
328+ if animation_frame is not None :
329+ animation_slices = img .coords [img .dims [animation_frame ]].values
330+ _ = dims .pop (animation_frame )
331+ animation_label = img .dims [animation_frame ]
321332 y_label , x_label = dims [0 ], dims [1 ]
322333 # np.datetime64 is not handled correctly by go.Heatmap
323334 for ax in [x_label , y_label ]:
@@ -333,8 +344,10 @@ def imshow(
333344 labels ["x" ] = x_label
334345 if labels .get ("y" , None ) is None :
335346 labels ["y" ] = y_label
336- if labels .get ("slice" , None ) is None :
337- labels ["slice" ] = slice_label
347+ if labels .get ("animation_slice" , None ) is None :
348+ labels ["animation_slice" ] = animation_label
349+ if labels .get ("facet_slice" , None ) is None :
350+ labels ["facet_slice" ] = facet_label
338351 if labels .get ("color" , None ) is None :
339352 labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
340353 labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -371,11 +384,15 @@ def imshow(
371384 img = np .asanyarray (img )
372385 if facet_col is not None :
373386 img = np .moveaxis (img , facet_col , 0 )
387+ print (img .shape )
388+ if animation_frame is not None and animation_frame < facet_col :
389+ animation_frame += 1
374390 facet_col = True
375391 if animation_frame is not None :
376392 img = np .moveaxis (img , animation_frame , 0 )
393+ print (img .shape )
377394 animation_frame = True
378- args ["animation_frame" ] = (
395+ args ["animation_frame" ] = ( # TODO
379396 "slice" if labels .get ("slice" ) is None else labels ["slice" ]
380397 )
381398
@@ -431,9 +448,16 @@ def imshow(
431448 + "dimension of the img matrix."
432449 )
433450 if slice_through :
451+ iterables = ()
452+ if animation_frame is not None :
453+ iterables += (range (nslices_animation ),)
454+ if facet_col is not None :
455+ iterables += (range (nslices_facet ),)
434456 traces = [
435- go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" , name = str (i ))
436- for i , img_slice in enumerate (img )
457+ go .Heatmap (
458+ x = x , y = y , z = img [index_tup ], coloraxis = "coloraxis1" , name = str (i )
459+ )
460+ for i , index_tup in enumerate (itertools .product (* iterables ))
437461 ]
438462 else :
439463 traces = [go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )]
@@ -464,11 +488,21 @@ def imshow(
464488 _vectorize_zvalue (zmin , mode = "min" ),
465489 _vectorize_zvalue (zmax , mode = "max" ),
466490 )
491+ if slice_through :
492+ iterables = ()
493+ if animation_frame is not None :
494+ iterables += (range (nslices_animation ),)
495+ if facet_col is not None :
496+ iterables += (range (nslices_facet ),)
467497 if binary_string :
468498 if zmin is None and zmax is None : # no rescaling, faster
469499 img_rescaled = img
470500 rescale_image = False
471- elif img .ndim == 2 or (img .ndim == 3 and slice_through ):
501+ elif (
502+ img .ndim == 2
503+ or (img .ndim == 3 and slice_through )
504+ or (img .ndim == 4 and double_slice_through )
505+ ):
472506 img_rescaled = rescale_intensity (
473507 img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
474508 )
@@ -485,14 +519,15 @@ def imshow(
485519 axis = - 1 ,
486520 )
487521 if slice_through :
522+ tuples = [index_tup for index_tup in itertools .product (* iterables )]
488523 img_str = [
489524 _array_to_b64str (
490- img_rescaled_slice ,
525+ img_rescaled [ index_tup ] ,
491526 backend = binary_backend ,
492527 compression = binary_compression_level ,
493528 ext = binary_format ,
494529 )
495- for img_rescaled_slice in img_rescaled
530+ for index_tup in itertools . product ( * iterables )
496531 ]
497532
498533 else :
@@ -512,8 +547,10 @@ def imshow(
512547 colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
513548 if slice_through :
514549 traces = [
515- go .Image (z = img_slice , zmin = zmin , zmax = zmax , colormodel = colormodel )
516- for img_slice in img
550+ go .Image (
551+ z = img [index_tup ], zmin = zmin , zmax = zmax , colormodel = colormodel
552+ )
553+ for index_tup in itertools .product (* iterables )
517554 ]
518555 else :
519556 traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
@@ -533,9 +570,9 @@ def imshow(
533570 col_labels = []
534571 if facet_col is not None :
535572 slice_label = "slice" if labels .get ("slice" ) is None else labels ["slice" ]
536- if slices is None :
537- slices = range (nslices )
538- col_labels = ["%s = %d" % (slice_label , i ) for i in slices ]
573+ if facet_slices is None :
574+ facet_slices = range (nslices_facet )
575+ col_labels = ["%s = %d" % (slice_label , i ) for i in facet_slices ]
539576 fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
540577 layout_patch = dict ()
541578 for attr_name in ["height" , "width" ]:
@@ -547,11 +584,18 @@ def imshow(
547584 layout_patch ["margin" ] = {"t" : 60 }
548585
549586 frame_list = []
550- for index , ( slice_index , trace ) in enumerate (zip ( slices , traces ) ):
551- if facet_col or index == 0 :
587+ for index , trace in enumerate (traces ):
588+ if ( facet_col and index < nrows * ncols ) or index == 0 :
552589 fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
553- if animation_frame :
554- frame_list .append (dict (data = trace , layout = layout , name = str (slice_index )))
590+ if animation_frame is not None :
591+ for i in range (nslices_animation ):
592+ frame_list .append (
593+ dict (
594+ data = traces [nslices_facet * i : nslices_facet * (i + 1 )],
595+ layout = layout ,
596+ name = str (i ),
597+ )
598+ )
555599 if animation_frame :
556600 fig .frames = frame_list
557601 fig .update_layout (layout )
0 commit comments