Skip to content

feat: add debugprint(file="rich") to return a rich.tree.Tree#2042

Open
williambdean wants to merge 8 commits intopymc-devs:v3from
williambdean:rich-print-graph
Open

feat: add debugprint(file="rich") to return a rich.tree.Tree#2042
williambdean wants to merge 8 commits intopymc-devs:v3from
williambdean:rich-print-graph

Conversation

@williambdean
Copy link
Copy Markdown
Contributor

Closes #1034

Adds file="rich" as a new sentinel to debugprint, returning a
rich.tree.Tree instead of printing text.

import pytensor.tensor as pt
from pytensor.printing import debugprint
from rich.console import Console

x = pt.vector("x")
tree = debugprint([x.mean(), x.std()], file="rich")

console = Console()
console.print(tree)
True_div [id A] 'mean'
├── Sum{axes=None} [id B]
│   └── x [id C]
└── Subtensor{i} [id D]
    ├── Cast{float64} [id E]
    │   └── Shape [id F]
    │       └── x [id C]
    └── 0 [id G]
Sqrt [id H] 'std'
└── True_div [id I] 'var'
    ├── Sum{axis=0} [id J]
    │   └── Pow [id K]
    │       ├── Sub [id L]
    │       │   ├── x [id C]
    │       │   └── True_div [id M] 'mean'
    │       │       ├── ExpandDims{axis=0} [id N]
    │       │       │   └── Sum{axis=0} [id O]
    │       │       │       └── x [id C]
    │       │       └── ExpandDims{axis=0} [id P]
    │       │           └── Subtensor{i} [id Q]
    │       │               ├── Cast{float64} [id R]
    │       │               │   └── Shape [id S]
    │       │               │       └── x [id C]
    │       │               └── 0 [id T]
    │       └── ExpandDims{axis=0} [id U]
    │           └── 2.0 [id V]
    └── Subtensor{i} [id W]
        ├── Cast{float64} [id X]
        │   └── Shape [id Y]
        │       └── Pow [id K] ···
        └── 0 [id Z]

rich is added as a hard dependency. The refactor also extracts
GraphNode, _assign_id, _build_label, and _iter_graph_nodes from
the monolithic _debugprint function, which both the text and rich
renderers now consume.

The "rich" sentinel slots in alongside "str" and None — the text
output path is unchanged.

@williambdean
Copy link
Copy Markdown
Contributor Author

williambdean commented Apr 10, 2026

The GraphNode / _iter_graph_nodes refactor also makes it straightforward to build other graph renderers on top of the same traversal. As a sketch, here is a Mermaid renderer using the same (x * 2).sum() graph from the tests:

Mermaid renderer sketch
import re
import pytensor.tensor as pt
from pytensor.printing import _build_label, _iter_graph_nodes

def build_mermaid(var) -> str:
    done: dict = {}
    used_ids: dict = {}
    node_defs: dict[str, str] = {}
    edges: list[tuple[str, str]] = []
    apply_to_mid: dict[int, str] = {}

    for gnode in _iter_graph_nodes(var, done=done):
        label = _build_label(
            gnode, done=done, used_ids=used_ids, id_type="CHAR",
            print_type=False, print_shape=False, print_destroy_map=False,
            print_view_map=False, print_op_info=False, op_information={},
        )
        mid = re.search(r"\[id ([^\]]+)\]", label).group(1)
        if mid not in node_defs:
            node_defs[mid] = label
        if gnode.parent_node is not None:
            parent_mid = apply_to_mid.get(id(gnode.parent_node))
            if parent_mid is not None:
                edges.append((parent_mid, mid))
        if gnode.var.owner is not None:
            apply_to_mid[id(gnode.var.owner)] = mid

    lines = ["graph TD"]
    for mid, label in node_defs.items():
        lines.append(f'    {mid}["{label.replace(chr(34), "#quot;")}"]')
    for src, dst in edges:
        lines.append(f"    {src} --> {dst}")
    return "\n".join(lines)

x = pt.dvector("x")
print(build_mermaid((x * 2).sum()))

Output:

graph TD
    A["Sum{axes=None} [id A]"]
    B["Mul [id B]"]
    C["x [id C]"]
    D["ExpandDims{axis=0} [id D]"]
    E["2 [id E]"]
    A --> B
    B --> C
    B --> D
    D --> E
Loading

This would also address #1488.

col_bars was iterating over all of ancestor_is_last, producing one extra
3-char column segment per node. The root-level ancestor entry should be
skipped (ancestor_is_last[1:]) because the root contributes an empty
prefix_child and column bar accumulation starts from depth-1 onwards.

This restores byte-for-byte compatibility with the pre-refactor text output.
The old approach marked a repeat/stop_on_name node itself as is_repeat=True,
replacing its label with ···. The new approach yields the node normally
(label visible) and then yields a separate sentinel child with is_repeat=True
one level deeper, so ··· appears indented below the node.

Also moves done[node] = "" to before yield gnode so DAG-diamond marking
is unconditional regardless of early generator abandonment.
@williambdean
Copy link
Copy Markdown
Contributor Author

williambdean commented Apr 10, 2026

Another one from the mermaid POC

x = pt.dvector("x")
y = (x * 2).sum()

z = pt.stack([y.mean(), y.std()])
graph TD
    A["MakeVector{dtype='float64'} [id A]"]
    B["True_div [id B] 'mean'"]
    C["Sum{axes=None} [id C]"]
    D["Sum{axes=None} [id D]"]
    E["Mul [id E]"]
    F["x [id F]"]
    G["ExpandDims{axis=0} [id G]"]
    H["2 [id H]"]
    I["Cast{float64} [id I]"]
    J["1 [id J]"]
    K["Sqrt [id K] 'std'"]
    L["True_div [id L] 'var'"]
    M["Sum{axes=[]} [id M]"]
    N["Pow [id N]"]
    O["Sub [id O]"]
    P["True_div [id P] 'mean'"]
    Q["Sum{axes=[]} [id Q]"]
    R["Cast{float64} [id R]"]
    S["1 [id S]"]
    T["2.0 [id T]"]
    U["Cast{float64} [id U]"]
    V["1 [id V]"]
    A --> B
    B --> C
    C --> D
    D --> E
    E --> F
    E --> G
    G --> H
    B --> I
    I --> J
    A --> K
    K --> L
    L --> M
    M --> N
    N --> O
    O --> D
    D --> D
    O --> P
    P --> Q
    Q --> D
    D --> D
    P --> R
    R --> S
    N --> T
    L --> U
    U --> V
Loading

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Explore textual representation with rich / textualize libraries

2 participants