Skip to content

Commit a9d6247

Browse files
FBumannclaude
andcommitted
feat: support faceted figures in subplots()
Rewrote subplots() to use manual axis domain management instead of make_subplots. Each figure's internal axes are remapped with scaled domains to fit within the grid cell, so faceted figures now work. Updated notebook with faceted subplots example. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dda84c7 commit a9d6247

File tree

3 files changed

+187
-66
lines changed

3 files changed

+187
-66
lines changed

docs/examples/combining.ipynb

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -539,9 +539,9 @@
539539
"id": "34",
540540
"metadata": {},
541541
"source": [
542-
"### Limitations\n",
542+
"### With Facets\n",
543543
"\n",
544-
"`subplots` requires single-panel figures — faceted and animated figures are not supported."
544+
"Faceted figures can be composed — each figure's internal subplots are remapped into the grid cell."
545545
]
546546
},
547547
{
@@ -551,19 +551,13 @@
551551
"metadata": {},
552552
"outputs": [],
553553
"source": [
554-
"# Faceted figure → rejected\n",
555-
"faceted = xpx(population).line(facet_col=\"country\")\n",
556-
"try:\n",
557-
" subplots(faceted)\n",
558-
"except ValueError as e:\n",
559-
" print(f\"ValueError: {e}\")\n",
554+
"# Faceted bar on top, faceted line below\n",
555+
"pop_faceted = xpx(population).bar(facet_col=\"country\")\n",
556+
"gdp_faceted = xpx(gdp_per_capita).line(facet_col=\"country\")\n",
560557
"\n",
561-
"# Animated figure → rejected\n",
562-
"animated = xpx(population).bar(animation_frame=\"country\")\n",
563-
"try:\n",
564-
" subplots(animated)\n",
565-
"except ValueError as e:\n",
566-
" print(f\"ValueError: {e}\")"
558+
"grid = subplots(pop_faceted, gdp_faceted, cols=1)\n",
559+
"grid.update_layout(height=600, showlegend=False)\n",
560+
"grid"
567561
]
568562
},
569563
{
@@ -757,7 +751,7 @@
757751
"|----------|--------|-----------|-------------------|\n",
758752
"| `overlay` | Yes (must match) | Yes (frames must match) | Static overlay on animated base OK |\n",
759753
"| `add_secondary_y` | Yes (must match) | Yes (frames must match) | Static secondary on animated base OK |\n",
760-
"| `subplots` | No (single-panel only) | No | N/A |"
754+
"| `subplots` | Yes (remapped into cells) | No | N/A |"
761755
]
762756
}
763757
],

tests/test_figures.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def test_titles_from_figure_title(self) -> None:
820820
fig2 = xpx(da).bar(title="Other Title")
821821
grid = subplots(fig1, fig2, cols=2)
822822
titles = [ann.text for ann in grid.layout.annotations]
823-
assert titles == ["My Title", "Other Title"]
823+
assert titles == ["<b>My Title</b>", "<b>Other Title</b>"]
824824

825825
def test_titles_from_yaxis_label(self) -> None:
826826
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature")
@@ -829,7 +829,7 @@ def test_titles_from_yaxis_label(self) -> None:
829829
fig2 = xpx(da2).line()
830830
grid = subplots(fig1, fig2, cols=2)
831831
titles = [ann.text for ann in grid.layout.annotations]
832-
assert titles == ["Temperature", "Pressure"]
832+
assert titles == ["<b>Temperature</b>", "<b>Pressure</b>"]
833833

834834
def test_titles_fallback_empty(self) -> None:
835835
grid = subplots(go.Figure(), go.Figure(), cols=2)
@@ -860,15 +860,21 @@ def test_invalid_cols_raises(self) -> None:
860860
with pytest.raises(ValueError, match="cols must be >= 1"):
861861
subplots(go.Figure(), cols=0)
862862

863-
def test_faceted_figure_raises(self) -> None:
863+
def test_faceted_figures_stacked(self) -> None:
864+
"""Faceted figures can be stacked in a subplot grid."""
864865
da = xr.DataArray(
865866
np.random.rand(10, 3),
866867
dims=["x", "facet"],
867868
coords={"facet": ["A", "B", "C"]},
868869
)
869-
fig = xpx(da).line(facet_col="facet")
870-
with pytest.raises(ValueError, match="internal subplots"):
871-
subplots(fig)
870+
fig1 = xpx(da).bar(facet_col="facet")
871+
fig2 = xpx(da).line(facet_col="facet")
872+
grid = subplots(fig1, fig2, cols=1)
873+
# 3 bar traces + 3 line traces
874+
assert len(grid.data) == 6
875+
# All traces should have unique axis assignments
876+
axes = {(t.xaxis, t.yaxis) for t in grid.data}
877+
assert len(axes) == 6
872878

873879
def test_animated_figure_raises(self) -> None:
874880
da = xr.DataArray(np.random.rand(10, 3), dims=["x", "time"])

xarray_plotly/figures.py

Lines changed: 166 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,9 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
624624
"""Arrange multiple figures into a subplot grid.
625625
626626
Creates a new figure with each input figure placed in its own cell.
627-
Subplot titles are derived from each figure's title or y-axis label.
627+
Figures may contain internal subplots (facets) — their axes are remapped
628+
to fit within the grid cell. Subplot titles are derived from each
629+
figure's title or y-axis label.
628630
629631
Args:
630632
*figs: One or more Plotly figures to arrange.
@@ -635,7 +637,7 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
635637
636638
Raises:
637639
ValueError: If no figures are provided, cols < 1, or a figure has
638-
internal subplots (facets) or animation frames.
640+
animation frames.
639641
640642
Example:
641643
>>> import numpy as np
@@ -650,48 +652,72 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
650652
"""
651653
import math
652654

653-
from plotly.subplots import make_subplots
655+
import plotly.graph_objects as go
654656

655657
if not figs:
656658
raise ValueError("At least one figure is required.")
657659
if cols < 1:
658660
raise ValueError(f"cols must be >= 1, got {cols}.")
659661

660-
# Validate inputs
661662
for i, fig in enumerate(figs):
662-
axes = _get_subplot_axes(fig)
663-
if len(axes) > 1:
664-
raise ValueError(
665-
f"Figure at position {i} has internal subplots (facets). "
666-
"Use single-panel figures with subplots()."
667-
)
668663
if fig.frames:
669664
raise ValueError(
670665
f"Figure at position {i} has animation frames. "
671666
"Animated figures are not supported in subplots()."
672667
)
673668

674669
rows = math.ceil(len(figs) / cols)
670+
combined = go.Figure()
675671

676-
# Derive subplot titles
677-
titles = [_get_figure_title(f) for f in figs]
678-
# Pad for empty trailing cells
679-
titles.extend("" for _ in range(rows * cols - len(figs)))
672+
# Grid spacing
673+
h_gap = 0.05
674+
v_gap = 0.08
675+
cell_w = (1.0 - h_gap * (cols - 1)) / cols
676+
cell_h = (1.0 - v_gap * (rows - 1)) / rows
680677

681-
grid = make_subplots(rows=rows, cols=cols, subplot_titles=titles)
678+
next_x_num = 1
679+
next_y_num = 1
682680

683-
# Add traces from each figure to the correct cell
684681
for i, fig in enumerate(figs):
685-
row = i // cols + 1
686-
col = i % cols + 1
682+
row = i // cols # 0-indexed, top to bottom
683+
col = i % cols
684+
685+
# Cell boundaries (clamped to [0, 1])
686+
cell_x0 = max(0.0, col * (cell_w + h_gap))
687+
cell_x1 = min(1.0, cell_x0 + cell_w)
688+
cell_y1 = min(1.0, 1.0 - row * (cell_h + v_gap)) # top-down
689+
cell_y0 = max(0.0, cell_y1 - cell_h)
690+
691+
# Build axis remapping: old axis ref → new axis ref
692+
axis_map, next_x_num, next_y_num = _remap_figure_axes(
693+
fig, combined, next_x_num, next_y_num, cell_x0, cell_x1, cell_y0, cell_y1
694+
)
687695

696+
# Add traces with remapped axis refs
688697
for trace in fig.data:
689-
grid.add_trace(copy.deepcopy(trace), row=row, col=col)
690-
691-
# Copy axis config from source figure to target cell
692-
_copy_axis_config(fig, grid, row, col)
698+
tc = copy.deepcopy(trace)
699+
old_x = getattr(tc, "xaxis", None) or "x"
700+
old_y = getattr(tc, "yaxis", None) or "y"
701+
tc.xaxis = axis_map[old_x]["new_x"]
702+
tc.yaxis = axis_map[old_y]["new_y"]
703+
combined.add_trace(tc)
704+
705+
# Add subplot title as annotation
706+
title = _get_figure_title(fig)
707+
if title:
708+
combined.add_annotation(
709+
text=f"<b>{title}</b>",
710+
x=(cell_x0 + cell_x1) / 2,
711+
y=cell_y1,
712+
xref="paper",
713+
yref="paper",
714+
xanchor="center",
715+
yanchor="bottom",
716+
showarrow=False,
717+
font={"size": 14},
718+
)
693719

694-
return grid
720+
return combined
695721

696722

697723
# Axis properties safe to copy between figures (display-only, not structural).
@@ -712,37 +738,132 @@ def subplots(*figs: go.Figure, cols: int = 1) -> go.Figure:
712738
"zeroline",
713739
"zerolinecolor",
714740
"zerolinewidth",
741+
"showticklabels",
715742
)
716743

717744

718-
def _copy_axis_config(src: go.Figure, grid: go.Figure, row: int, col: int) -> None:
719-
"""Copy display-related axis properties from a source figure to a grid cell.
745+
def _axis_layout_key(ref: str) -> str:
746+
"""Convert axis reference to layout property name.
720747
721-
Args:
722-
src: Source figure whose axis config to copy.
723-
grid: Target subplot grid figure.
724-
row: Target row (1-indexed).
725-
col: Target column (1-indexed).
748+
``"x"`` → ``"xaxis"``, ``"x2"`` → ``"xaxis2"``,
749+
``"y"`` → ``"yaxis"``, ``"y3"`` → ``"yaxis3"``.
726750
"""
727-
# Get the xaxis/yaxis objects for the target cell
728-
xref, yref = grid.get_subplot(row, col)
751+
if ref in ("x", "y"):
752+
return f"{ref}axis"
753+
prefix = ref[0] # "x" or "y"
754+
num = ref[1:]
755+
return f"{prefix}axis{num}"
729756

730-
# Convert plotly axis objects to layout property names
731-
# xref.plotly_name is e.g. "xaxis" or "xaxis2"
732-
x_layout_key = xref.plotly_name
733-
y_layout_key = yref.plotly_name
734757

735-
src_xaxis = src.layout.xaxis or {}
736-
src_yaxis = src.layout.yaxis or {}
758+
def _new_axis_ref(prefix: str, num: int) -> str:
759+
"""Build an axis reference string. ``_new_axis_ref("x", 1)`` → ``"x"``, ``("x", 3)`` → ``"x3"``."""
760+
return prefix if num == 1 else f"{prefix}{num}"
737761

738-
for prop in _AXIS_PROPS_TO_COPY:
739-
xval = getattr(src_xaxis, prop, None)
740-
if xval is not None:
741-
grid.layout[x_layout_key][prop] = xval
742762

743-
yval = getattr(src_yaxis, prop, None)
744-
if yval is not None:
745-
grid.layout[y_layout_key][prop] = yval
763+
def _remap_figure_axes(
764+
fig: go.Figure,
765+
combined: go.Figure,
766+
next_x_num: int,
767+
next_y_num: int,
768+
cell_x0: float,
769+
cell_x1: float,
770+
cell_y0: float,
771+
cell_y1: float,
772+
) -> tuple[dict[str, dict[str, str]], int, int]:
773+
"""Remap a figure's axes into a grid cell, adding axis configs to the combined layout.
774+
775+
Args:
776+
fig: Source figure.
777+
combined: Target combined figure (mutated — axis configs added to layout).
778+
next_x_num: Next available x-axis number.
779+
next_y_num: Next available y-axis number.
780+
cell_x0, cell_x1: Horizontal cell bounds in paper coordinates.
781+
cell_y0, cell_y1: Vertical cell bounds in paper coordinates.
782+
783+
Returns:
784+
Tuple of (axis_map, next_x_num, next_y_num).
785+
axis_map maps old axis refs to ``{"new_x": ...}`` or ``{"new_y": ...}``.
786+
"""
787+
cell_w = cell_x1 - cell_x0
788+
cell_h = cell_y1 - cell_y0
789+
src_layout = fig.layout.to_plotly_json()
790+
791+
x_remap: dict[str, str] = {}
792+
y_remap: dict[str, str] = {}
793+
794+
# Get all unique axis refs
795+
x_refs: set[str] = set()
796+
y_refs: set[str] = set()
797+
for trace in fig.data:
798+
x_refs.add(getattr(trace, "xaxis", None) or "x")
799+
y_refs.add(getattr(trace, "yaxis", None) or "y")
800+
801+
# Remap x-axes
802+
for old_xref in sorted(x_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1):
803+
new_xref = _new_axis_ref("x", next_x_num)
804+
x_remap[old_xref] = new_xref
805+
806+
src_config = src_layout.get(_axis_layout_key(old_xref), {})
807+
src_domain = src_config.get("domain", [0.0, 1.0])
808+
new_domain = [
809+
max(0.0, cell_x0 + src_domain[0] * cell_w),
810+
min(1.0, cell_x0 + src_domain[1] * cell_w),
811+
]
812+
813+
new_config: dict[str, Any] = {"domain": new_domain}
814+
for prop in _AXIS_PROPS_TO_COPY:
815+
if prop in src_config:
816+
new_config[prop] = src_config[prop]
817+
818+
combined.layout[_axis_layout_key(new_xref)] = new_config
819+
next_x_num += 1
820+
821+
# Remap y-axes
822+
for old_yref in sorted(y_refs, key=lambda r: int(r[1:]) if len(r) > 1 else 1):
823+
new_yref = _new_axis_ref("y", next_y_num)
824+
y_remap[old_yref] = new_yref
825+
826+
src_config = src_layout.get(_axis_layout_key(old_yref), {})
827+
src_domain = src_config.get("domain", [0.0, 1.0])
828+
new_domain = [
829+
max(0.0, cell_y0 + src_domain[0] * cell_h),
830+
min(1.0, cell_y0 + src_domain[1] * cell_h),
831+
]
832+
833+
new_config = {"domain": new_domain}
834+
for prop in _AXIS_PROPS_TO_COPY:
835+
if prop in src_config:
836+
new_config[prop] = src_config[prop]
837+
838+
combined.layout[_axis_layout_key(new_yref)] = new_config
839+
next_y_num += 1
840+
841+
# Set anchors between paired axes
842+
for trace in fig.data:
843+
old_x = getattr(trace, "xaxis", None) or "x"
844+
old_y = getattr(trace, "yaxis", None) or "y"
845+
combined.layout[_axis_layout_key(x_remap[old_x])]["anchor"] = y_remap[old_y]
846+
combined.layout[_axis_layout_key(y_remap[old_y])]["anchor"] = x_remap[old_x]
847+
848+
# Propagate matches relationships
849+
for old_ref, new_ref in x_remap.items():
850+
src_config = src_layout.get(_axis_layout_key(old_ref), {})
851+
if "matches" in src_config and src_config["matches"] in x_remap:
852+
combined.layout[_axis_layout_key(new_ref)]["matches"] = x_remap[src_config["matches"]]
853+
854+
for old_ref, new_ref in y_remap.items():
855+
src_config = src_layout.get(_axis_layout_key(old_ref), {})
856+
if "matches" in src_config and src_config["matches"] in y_remap:
857+
combined.layout[_axis_layout_key(new_ref)]["matches"] = y_remap[src_config["matches"]]
858+
859+
# Build combined return mapping
860+
result: dict[str, dict[str, str]] = {}
861+
for old_x, new_x in x_remap.items():
862+
result[old_x] = {"new_x": new_x}
863+
for old_y, new_y in y_remap.items():
864+
result[old_y] = {"new_y": new_y}
865+
866+
return result, next_x_num, next_y_num
746867

747868

748869
def update_traces(

0 commit comments

Comments
 (0)