@@ -314,8 +314,9 @@ def imshow(
314314 animation_frame = img .dims .index (animation_frame )
315315 nslices_animation = img .shape [animation_frame ]
316316 animation_slices = range (nslices_animation )
317- slice_through = (facet_col is not None ) or (animation_frame is not None )
318- double_slice_through = (facet_col is not None ) and (animation_frame is not None )
317+ slice_dimensions = (facet_col is not None ) + (
318+ animation_frame is not None
319+ ) # 0, 1, or 2
319320 facet_label = None
320321 animation_label = None
321322 # ----- Define x and y, set labels if img is an xarray -------------------
@@ -344,10 +345,10 @@ def imshow(
344345 labels ["x" ] = x_label
345346 if labels .get ("y" , None ) is None :
346347 labels ["y" ] = y_label
347- if labels .get ("animation_slice " , None ) is None :
348- labels ["animation_slice " ] = animation_label
349- if labels .get ("facet_slice " , None ) is None :
350- labels ["facet_slice " ] = facet_label
348+ if labels .get ("animation " , None ) is None :
349+ labels ["animation " ] = animation_label
350+ if labels .get ("facet " , None ) is None :
351+ labels ["facet " ] = facet_label
351352 if labels .get ("color" , None ) is None :
352353 labels ["color" ] = xarray .plot .utils .label_from_attrs (img )
353354 labels ["color" ] = labels ["color" ].replace ("\n " , "<br>" )
@@ -382,32 +383,27 @@ def imshow(
382383
383384 # --------------- Starting from here img is always a numpy array --------
384385 img = np .asanyarray (img )
386+ # Reshape array so that animation dimension comes first, then facets, then images
385387 if facet_col is not None :
386388 img = np .moveaxis (img , facet_col , 0 )
387- print (img .shape )
388389 if animation_frame is not None and animation_frame < facet_col :
389390 animation_frame += 1
390391 facet_col = True
391392 if animation_frame is not None :
392393 img = np .moveaxis (img , animation_frame , 0 )
393- print (img .shape )
394394 animation_frame = True
395- args ["animation_frame" ] = ( # TODO
396- "slice" if labels .get ("slice " ) is None else labels ["slice " ]
395+ args ["animation_frame" ] = (
396+ "slice" if labels .get ("animation " ) is None else labels ["animation " ]
397397 )
398398 iterables = ()
399- if slice_through :
400- if animation_frame is not None :
401- iterables += (range (nslices_animation ),)
402- if facet_col is not None :
403- iterables += (range (nslices_facet ),)
399+ if animation_frame is not None :
400+ iterables += (range (nslices_animation ),)
401+ if facet_col is not None :
402+ iterables += (range (nslices_facet ),)
404403
405404 # Default behaviour of binary_string: True for RGB images, False for 2D
406405 if binary_string is None :
407- if slice_through :
408- binary_string = img .ndim >= 4 and not is_dataframe
409- else :
410- binary_string = img .ndim >= 3 and not is_dataframe
406+ binary_string = img .ndim >= (3 + slice_dimensions ) and not is_dataframe
411407
412408 # Cast bools to uint8 (also one byte)
413409 if img .dtype == np .bool :
@@ -419,11 +415,7 @@ def imshow(
419415
420416 # -------- Contrast rescaling: either minmax or infer ------------------
421417 if contrast_rescaling is None :
422- contrast_rescaling = (
423- "minmax"
424- if (img .ndim == 2 or (img .ndim == 3 and slice_through ))
425- else "infer"
426- )
418+ contrast_rescaling = "minmax" if img .ndim == (2 + slice_dimensions ) else "infer"
427419
428420 # We try to set zmin and zmax only if necessary, because traces have good defaults
429421 if contrast_rescaling == "minmax" :
@@ -439,19 +431,15 @@ def imshow(
439431 if zmin is None and zmax is not None :
440432 zmin = 0
441433
442- # For 2d data, use Heatmap trace, unless binary_string is True
443- if (
444- img .ndim == 2
445- or (img .ndim == 3 and slice_through )
446- or (img .ndim == 4 and double_slice_through )
447- ) and not binary_string :
448- y_index = 1 if slice_through else 0
434+ # For 2d data, use Heatmap trace, unless binary_string is True
435+ if img .ndim == 2 + slice_dimensions and not binary_string :
436+ y_index = slice_dimensions
449437 if y is not None and img .shape [y_index ] != len (y ):
450438 raise ValueError (
451439 "The length of the y vector must match the length of the first "
452440 + "dimension of the img matrix."
453441 )
454- x_index = 2 if slice_through else 1
442+ x_index = slice_dimensions + 1
455443 if x is not None and img .shape [x_index ] != len (x ):
456444 raise ValueError (
457445 "The length of the x vector must match the length of the second "
@@ -480,7 +468,8 @@ def imshow(
480468
481469 # For 2D+RGB data, use Image trace
482470 elif (
483- img .ndim >= 3 and (img .shape [- 1 ] in [3 , 4 ] or slice_through and binary_string )
471+ img .ndim >= 3
472+ and (img .shape [- 1 ] in [3 , 4 ] or slice_dimensions and binary_string )
484473 ) or (img .ndim == 2 and binary_string ):
485474 rescale_image = True # to check whether image has been modified
486475 if zmin is not None and zmax is not None :
@@ -492,11 +481,7 @@ def imshow(
492481 if zmin is None and zmax is None : # no rescaling, faster
493482 img_rescaled = img
494483 rescale_image = False
495- elif (
496- img .ndim == 2
497- or (img .ndim == 3 and slice_through )
498- or (img .ndim == 4 and double_slice_through )
499- ):
484+ elif img .ndim == 2 + slice_dimensions : # single-channel image
500485 img_rescaled = rescale_intensity (
501486 img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
502487 )
@@ -547,9 +532,7 @@ def imshow(
547532 # Now build figure
548533 col_labels = []
549534 if facet_col is not None :
550- slice_label = "slice" if labels .get ("slice" ) is None else labels ["slice" ]
551- if facet_slices is None :
552- facet_slices = range (nslices_facet )
535+ slice_label = "slice" if labels .get ("facet" ) is None else labels ["facet" ]
553536 col_labels = ["%s = %d" % (slice_label , i ) for i in facet_slices ]
554537 fig = init_figure (args , "xy" , [], nrows , ncols , col_labels , [])
555538 layout_patch = dict ()
@@ -566,12 +549,12 @@ def imshow(
566549 if (facet_col and index < nrows * ncols ) or index == 0 :
567550 fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
568551 if animation_frame is not None :
569- for i in range (nslices_animation ):
552+ for i , index in zip ( range (nslices_animation ), animation_slices ):
570553 frame_list .append (
571554 dict (
572555 data = traces [nslices_facet * i : nslices_facet * (i + 1 )],
573556 layout = layout ,
574- name = str (i ),
557+ name = str (index ),
575558 )
576559 )
577560 if animation_frame :
@@ -607,5 +590,5 @@ def imshow(
607590 if labels ["y" ]:
608591 fig .update_yaxes (title_text = labels ["y" ])
609592 configure_animation_controls (args , go .Image , fig )
610- # fig.update_layout(template=args["template"], overwrite=True)
593+ fig .update_layout (template = args ["template" ], overwrite = True )
611594 return fig
0 commit comments