@@ -624,7 +624,9 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
624624 """Arrange multiple figures into a subplot grid.
625625
626626 Creates a new figure with each input figure placed in its own cell.
627- Subplot titles are derived from each figure's title or y-axis label.
627+ Figures may contain internal subplots (facets) — their axes are remapped
628+ to fit within the grid cell. Subplot titles are derived from each
629+ figure's title or y-axis label.
628630
629631 Args:
630632 *figs: One or more Plotly figures to arrange.
@@ -635,7 +637,7 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
635637
636638 Raises:
637639 ValueError: If no figures are provided, cols < 1, or a figure has
638- internal subplots (facets) or animation frames.
640+ animation frames.
639641
640642 Example:
641643 >>> import numpy as np
@@ -650,48 +652,72 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
650652 """
651653 import math
652654
653- from plotly .subplots import make_subplots
655+ import plotly .graph_objects as go
654656
655657 if not figs :
656658 raise ValueError ("At least one figure is required." )
657659 if cols < 1 :
658660 raise ValueError (f"cols must be >= 1, got { cols } ." )
659661
660- # Validate inputs
661662 for i , fig in enumerate (figs ):
662- axes = _get_subplot_axes (fig )
663- if len (axes ) > 1 :
664- raise ValueError (
665- f"Figure at position { i } has internal subplots (facets). "
666- "Use single-panel figures with subplots()."
667- )
668663 if fig .frames :
669664 raise ValueError (
670665 f"Figure at position { i } has animation frames. "
671666 "Animated figures are not supported in subplots()."
672667 )
673668
674669 rows = math .ceil (len (figs ) / cols )
670+ combined = go .Figure ()
675671
676- # Derive subplot titles
677- titles = [_get_figure_title (f ) for f in figs ]
678- # Pad for empty trailing cells
679- titles .extend ("" for _ in range (rows * cols - len (figs )))
672+ # Grid spacing
673+ h_gap = 0.05
674+ v_gap = 0.08
675+ cell_w = (1.0 - h_gap * (cols - 1 )) / cols
676+ cell_h = (1.0 - v_gap * (rows - 1 )) / rows
680677
681- grid = make_subplots (rows = rows , cols = cols , subplot_titles = titles )
678+ next_x_num = 1
679+ next_y_num = 1
682680
683- # Add traces from each figure to the correct cell
684681 for i , fig in enumerate (figs ):
685- row = i // cols + 1
686- col = i % cols + 1
682+ row = i // cols # 0-indexed, top to bottom
683+ col = i % cols
684+
685+ # Cell boundaries (clamped to [0, 1])
686+ cell_x0 = max (0.0 , col * (cell_w + h_gap ))
687+ cell_x1 = min (1.0 , cell_x0 + cell_w )
688+ cell_y1 = min (1.0 , 1.0 - row * (cell_h + v_gap )) # top-down
689+ cell_y0 = max (0.0 , cell_y1 - cell_h )
690+
691+ # Build axis remapping: old axis ref → new axis ref
692+ axis_map , next_x_num , next_y_num = _remap_figure_axes (
693+ fig , combined , next_x_num , next_y_num , cell_x0 , cell_x1 , cell_y0 , cell_y1
694+ )
687695
696+ # Add traces with remapped axis refs
688697 for trace in fig .data :
689- grid .add_trace (copy .deepcopy (trace ), row = row , col = col )
690-
691- # Copy axis config from source figure to target cell
692- _copy_axis_config (fig , grid , row , col )
698+ tc = copy .deepcopy (trace )
699+ old_x = getattr (tc , "xaxis" , None ) or "x"
700+ old_y = getattr (tc , "yaxis" , None ) or "y"
701+ tc .xaxis = axis_map [old_x ]["new_x" ]
702+ tc .yaxis = axis_map [old_y ]["new_y" ]
703+ combined .add_trace (tc )
704+
705+ # Add subplot title as annotation
706+ title = _get_figure_title (fig )
707+ if title :
708+ combined .add_annotation (
709+ text = f"<b>{ title } </b>" ,
710+ x = (cell_x0 + cell_x1 ) / 2 ,
711+ y = cell_y1 ,
712+ xref = "paper" ,
713+ yref = "paper" ,
714+ xanchor = "center" ,
715+ yanchor = "bottom" ,
716+ showarrow = False ,
717+ font = {"size" : 14 },
718+ )
693719
694- return grid
720+ return combined
695721
696722
697723# Axis properties safe to copy between figures (display-only, not structural).
@@ -712,37 +738,132 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
712738 "zeroline" ,
713739 "zerolinecolor" ,
714740 "zerolinewidth" ,
741+ "showticklabels" ,
715742)
716743
717744
718- def _copy_axis_config ( src : go . Figure , grid : go . Figure , row : int , col : int ) -> None :
719- """Copy display-related axis properties from a source figure to a grid cell .
745+ def _axis_layout_key ( ref : str ) -> str :
746+ """Convert axis reference to layout property name .
720747
721- Args:
722- src: Source figure whose axis config to copy.
723- grid: Target subplot grid figure.
724- row: Target row (1-indexed).
725- col: Target column (1-indexed).
748+ ``"x"`` → ``"xaxis"``, ``"x2"`` → ``"xaxis2"``,
749+ ``"y"`` → ``"yaxis"``, ``"y3"`` → ``"yaxis3"``.
726750 """
727- # Get the xaxis/yaxis objects for the target cell
728- xref , yref = grid .get_subplot (row , col )
751+ if ref in ("x" , "y" ):
752+ return f"{ ref } axis"
753+ prefix = ref [0 ] # "x" or "y"
754+ num = ref [1 :]
755+ return f"{ prefix } axis{ num } "
729756
730- # Convert plotly axis objects to layout property names
731- # xref.plotly_name is e.g. "xaxis" or "xaxis2"
732- x_layout_key = xref .plotly_name
733- y_layout_key = yref .plotly_name
734757
735- src_xaxis = src .layout .xaxis or {}
736- src_yaxis = src .layout .yaxis or {}
758+ def _new_axis_ref (prefix : str , num : int ) -> str :
759+ """Build an axis reference string. ``_new_axis_ref("x", 1)`` → ``"x"``, ``("x", 3)`` → ``"x3"``."""
760+ return prefix if num == 1 else f"{ prefix } { num } "
737761
738- for prop in _AXIS_PROPS_TO_COPY :
739- xval = getattr (src_xaxis , prop , None )
740- if xval is not None :
741- grid .layout [x_layout_key ][prop ] = xval
742762
743- yval = getattr (src_yaxis , prop , None )
744- if yval is not None :
745- grid .layout [y_layout_key ][prop ] = yval
763+ def _remap_figure_axes (
764+ fig : go .Figure ,
765+ combined : go .Figure ,
766+ next_x_num : int ,
767+ next_y_num : int ,
768+ cell_x0 : float ,
769+ cell_x1 : float ,
770+ cell_y0 : float ,
771+ cell_y1 : float ,
772+ ) -> tuple [dict [str , dict [str , str ]], int , int ]:
773+ """Remap a figure's axes into a grid cell, adding axis configs to the combined layout.
774+
775+ Args:
776+ fig: Source figure.
777+ combined: Target combined figure (mutated — axis configs added to layout).
778+ next_x_num: Next available x-axis number.
779+ next_y_num: Next available y-axis number.
780+ cell_x0, cell_x1: Horizontal cell bounds in paper coordinates.
781+ cell_y0, cell_y1: Vertical cell bounds in paper coordinates.
782+
783+ Returns:
784+ Tuple of (axis_map, next_x_num, next_y_num).
785+ axis_map maps old axis refs to ``{"new_x": ...}`` or ``{"new_y": ...}``.
786+ """
787+ cell_w = cell_x1 - cell_x0
788+ cell_h = cell_y1 - cell_y0
789+ src_layout = fig .layout .to_plotly_json ()
790+
791+ x_remap : dict [str , str ] = {}
792+ y_remap : dict [str , str ] = {}
793+
794+ # Get all unique axis refs
795+ x_refs : set [str ] = set ()
796+ y_refs : set [str ] = set ()
797+ for trace in fig .data :
798+ x_refs .add (getattr (trace , "xaxis" , None ) or "x" )
799+ y_refs .add (getattr (trace , "yaxis" , None ) or "y" )
800+
801+ # Remap x-axes
802+ for old_xref in sorted (x_refs , key = lambda r : int (r [1 :]) if len (r ) > 1 else 1 ):
803+ new_xref = _new_axis_ref ("x" , next_x_num )
804+ x_remap [old_xref ] = new_xref
805+
806+ src_config = src_layout .get (_axis_layout_key (old_xref ), {})
807+ src_domain = src_config .get ("domain" , [0.0 , 1.0 ])
808+ new_domain = [
809+ max (0.0 , cell_x0 + src_domain [0 ] * cell_w ),
810+ min (1.0 , cell_x0 + src_domain [1 ] * cell_w ),
811+ ]
812+
813+ new_config : dict [str , Any ] = {"domain" : new_domain }
814+ for prop in _AXIS_PROPS_TO_COPY :
815+ if prop in src_config :
816+ new_config [prop ] = src_config [prop ]
817+
818+ combined .layout [_axis_layout_key (new_xref )] = new_config
819+ next_x_num += 1
820+
821+ # Remap y-axes
822+ for old_yref in sorted (y_refs , key = lambda r : int (r [1 :]) if len (r ) > 1 else 1 ):
823+ new_yref = _new_axis_ref ("y" , next_y_num )
824+ y_remap [old_yref ] = new_yref
825+
826+ src_config = src_layout .get (_axis_layout_key (old_yref ), {})
827+ src_domain = src_config .get ("domain" , [0.0 , 1.0 ])
828+ new_domain = [
829+ max (0.0 , cell_y0 + src_domain [0 ] * cell_h ),
830+ min (1.0 , cell_y0 + src_domain [1 ] * cell_h ),
831+ ]
832+
833+ new_config = {"domain" : new_domain }
834+ for prop in _AXIS_PROPS_TO_COPY :
835+ if prop in src_config :
836+ new_config [prop ] = src_config [prop ]
837+
838+ combined .layout [_axis_layout_key (new_yref )] = new_config
839+ next_y_num += 1
840+
841+ # Set anchors between paired axes
842+ for trace in fig .data :
843+ old_x = getattr (trace , "xaxis" , None ) or "x"
844+ old_y = getattr (trace , "yaxis" , None ) or "y"
845+ combined .layout [_axis_layout_key (x_remap [old_x ])]["anchor" ] = y_remap [old_y ]
846+ combined .layout [_axis_layout_key (y_remap [old_y ])]["anchor" ] = x_remap [old_x ]
847+
848+ # Propagate matches relationships
849+ for old_ref , new_ref in x_remap .items ():
850+ src_config = src_layout .get (_axis_layout_key (old_ref ), {})
851+ if "matches" in src_config and src_config ["matches" ] in x_remap :
852+ combined .layout [_axis_layout_key (new_ref )]["matches" ] = x_remap [src_config ["matches" ]]
853+
854+ for old_ref , new_ref in y_remap .items ():
855+ src_config = src_layout .get (_axis_layout_key (old_ref ), {})
856+ if "matches" in src_config and src_config ["matches" ] in y_remap :
857+ combined .layout [_axis_layout_key (new_ref )]["matches" ] = y_remap [src_config ["matches" ]]
858+
859+ # Build combined return mapping
860+ result : dict [str , dict [str , str ]] = {}
861+ for old_x , new_x in x_remap .items ():
862+ result [old_x ] = {"new_x" : new_x }
863+ for old_y , new_y in y_remap .items ():
864+ result [old_y ] = {"new_y" : new_y }
865+
866+ return result , next_x_num , next_y_num
746867
747868
748869def update_traces (
0 commit comments