xtensor: inline Einsum OFG in lower_dot to avoid ShapeFeature compile blow-up #2164

Open
cetagostini wants to merge 3 commits into pymc-devs:main from cetagostini:xtensor-lower-dot-inline-einsum-ofg

@cetagostini
Contributor

Found while writing the tiny-transformer gallery notebook in #2163: compile time for multi-layer xtensor attention with pytensor.grad grew super-linearly, to the point of being unusable past 2 layers.

Root cause

pt.einsum wraps its output in an Einsum OpFromGraph. inline_optimized_einsum only inlines the OFG during specialize, but while it is alive ShapeFeature.on_import calls OpFromGraph.infer_shape on every node import during canonicalize, and infer_shape re-walks the inner graph each time. With many composed xtensor.dots (e.g. multi-layer attention) this becomes super-linear and dominates compile time.

cProfile of a 3-layer xtensor attention compile on main (32s total):

 0.753s  ShapeFeature.on_import
 0.466s  OpFromGraph.infer_shape           <-- the trap
 0.334s  OpFromGraph.local_traverse (walks inner graph)

This is the same family of ShapeFeature issues #2056 is targeting.

Fix

Inline the Einsum OFG immediately after building it in lower_dot, so ShapeFeature never sees it. The 2-operand Einsum produced by lower_dot has no path optimisation to defer, so inlining is safe and behaviour-preserving.
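
As a rough illustration of what inlining means here, the sketch below uses a toy expression graph in plain Python. It is not PyTensor code: `Node`, `Arg`, `substitute`, `inline_wrappers`, and `ops` are hypothetical names invented for this example. A wrapper node (standing in for the Einsum OFG) carries an inner graph, and inlining replaces the wrapper with that inner graph applied to the wrapper's actual inputs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Arg:
    """Placeholder inside an inner graph for the wrapper's i-th input."""
    index: int


@dataclass(frozen=True)
class Node:
    """Node of a toy expression graph (hypothetical, not a PyTensor class)."""
    op: str
    inputs: tuple = ()
    inner: object = None  # wrapper nodes only: inner graph over Arg leaves


def substitute(inner, actual_inputs):
    """Copy an inner graph, replacing each Arg(i) with actual_inputs[i]."""
    if isinstance(inner, Arg):
        return actual_inputs[inner.index]
    new_inputs = tuple(substitute(i, actual_inputs) for i in inner.inputs)
    return Node(inner.op, new_inputs, inner.inner)


def inline_wrappers(node):
    """Return an equivalent graph in which no wrapper node remains."""
    inputs = tuple(inline_wrappers(i) for i in node.inputs)
    if node.inner is None:
        return Node(node.op, inputs)
    # Replace the wrapper by its inner graph; recurse in case the inner
    # graph itself contains wrappers.
    return inline_wrappers(substitute(node.inner, inputs))


def ops(node):
    """Yield the op name of every node reachable from `node`."""
    yield node.op
    for i in node.inputs:
        yield from ops(i)


# A wrapper around a single dot, consumed by a sum.
x, w = Node("x"), Node("w")
einsum = Node("Einsum", (x, w), inner=Node("dot", (Arg(0), Arg(1))))
out = Node("sum", (einsum,))

lowered = inline_wrappers(out)
assert "Einsum" not in set(ops(lowered))
```

After inlining, later passes over the toy graph only ever see the plain `dot` node, which is the effect this PR aims for with the real Einsum OFG.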

Reproducer

import time
import numpy as np
import pytensor, pytensor.tensor as pt, pytensor.xtensor as px
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.xtensor.shape import stack as xstack

B, T, E, H, HD = 4, 32, 64, 4, 16
N_LAYERS = 3
rng = np.random.default_rng(0)

def attn(x):
    Wqkv = px.as_xtensor(
        pytensor.shared(rng.normal(size=(E, 3, H, HD))),
        dims=("embd", "qkv", "head", "hd"),
    )
    Wproj = px.as_xtensor(
        pytensor.shared(rng.normal(size=(E, E))),
        dims=("embd", "embd_out"),
    )
    qkv = px.dot(x, Wqkv, dim="embd")
    q = qkv.isel(qkv=0).rename(time="time_q")
    k = qkv.isel(qkv=1).rename(time="time_k")
    v = qkv.isel(qkv=2).rename(time="time_k")
    s = px.dot(q, k, dim="hd") / np.sqrt(HD)
    mask = px.as_xtensor(
        pt.tril(pt.ones((T, T), dtype="bool")), dims=("time_q", "time_k"),
    )
    a = px.math.softmax(px.where(mask, s, np.float64(-1e9)), dim="time_k")
    o = xstack(px.dot(a, v, dim="time_k"), embd=("head", "hd"))
    return px.dot(o, Wproj, dim="embd").rename(time_q="time", embd_out="embd")

x_t = pt.tensor("x", shape=(B, T, E))
x = px.as_xtensor(x_t, dims=("batch", "time", "embd"))
for _ in range(N_LAYERS):
    x = attn(x)
loss = rewrite_graph(x.values.sum(), include=("lower_xtensor",))
grad = pytensor.grad(loss, x_t)

# Time only the pytensor.function call (graph rewriting and compilation).
t0 = time.perf_counter()
pytensor.function([x_t], [loss, grad])
print(f"compile: {time.perf_counter() - t0:.2f}s")

Compile time, single thread, M-class macOS, NUMBA mode:

    N_LAYERS   main      this PR   speedup
    2          3.92s     0.66s
    3          34.10s    1.06s     32×
    4          338.01s   2.07s     163×
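
To make the growth pattern concrete, the short script below recomputes the ratios from the timings reported above. It uses only those numbers; the variable and function names are illustrative.

```python
# Compile times in seconds, copied from the table above, keyed by N_LAYERS.
main = {2: 3.92, 3: 34.10, 4: 338.01}
this_pr = {2: 0.66, 3: 1.06, 4: 2.07}

# Speedup of this PR over main for each layer count.
speedup = {n: main[n] / this_pr[n] for n in main}


def growth(timings):
    """Ratio of compile time between consecutive layer counts."""
    layers = sorted(timings)
    return {b: timings[b] / timings[a] for a, b in zip(layers, layers[1:])}


for n in sorted(main):
    print(f"{n} layers: {speedup[n]:.1f}x speedup")
print("main, growth per extra layer:", growth(main))
print("PR, growth per extra layer:  ", growth(this_pr))
```

On main, each extra layer multiplies compile time by roughly 9 to 10, while on this branch the factor stays below 2.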

The PR adds a structural test that locks in the post-lower_xtensor invariant "no OpFromGraph nodes left in the lowered graph".

Relation to #2163

#2163 (tiny transformer gallery notebook) hits this in its current form and the notebook is unusable without it. Once this PR lands, #2163 will be updated to drop its own (incorrect) workaround attempts and rebase.

Test plan

  • tests/xtensor/ (non-random): 190 passed, 6 skipped, 1 xfailed (190 → 191 with new structural test)
  • tests/tensor/test_math.py::TestTensordot / TestMatMul: 24 passed, 1 xfailed
  • tests/tensor/test_einsum.py: 46 passed
  • Numerical: 2-layer xtensor vs equivalent plain attention → loss diff 0.000e+00, max grad diff 5.03e-17
  • Run the reproducer above on this branch and on main; confirm the speedup

Made with Cursor

cetagostini and others added 3 commits May 22, 2026 22:04
`pt.einsum` wraps its output in an `Einsum` `OpFromGraph`. The OFG is
only inlined by `inline_optimized_einsum` during `specialize`, but while
it is alive `ShapeFeature.on_import` calls `OpFromGraph.infer_shape` on
every node import during canonicalize, and `infer_shape` re-walks the
OFG's inner graph each time. When several xtensor dots are composed
(e.g. multi-layer attention), this becomes super-linear and dominates
compile time.

Inlining the OFG immediately after `einsum` removes it before any
shape-using pass ever sees it. The 2-operand case `lower_dot` produces
has no path optimisation to defer, so inlining is safe and behaviour-
preserving.

Effect on the toy multi-head attention reproducer
(block_size=32, n_embd=64, n_head=4, with grad):

    n_layer  plain    xtensor (before)   xtensor (after)
    1        0.94s    3.04s              1.12s
    2        2.03s    72.50s             4.07s

Adds a structural test that locks in the post-`lower_xtensor` invariant
"no OpFromGraph nodes left in the lowered graph". All existing
xtensor / tensordot / einsum tests still pass.

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
After inlining the Einsum OFG in lower_dot, the runtime shape mismatch
is now raised by the inlined Dot directly (`Shape mismatch: x has ...`)
instead of by np.einsum/np.dot inside the OFG's wrapper. Add that
message to the regex so the test passes on all backends.

Co-authored-by: Cursor <cursoragent@cursor.com>
@ricardoV94
Member

> Inline the Einsum OFG immediately after building it in lower_dot, so ShapeFeature never sees it. The 2-operand Einsum produced by lower_dot has no path optimisation to defer, so inlining is safe and behaviour-preserving.

The fix clearly belongs in ShapeFeature, which I linked in the other PR as already being worked on. Did you check whether simply including those changes fixes the problem?

@cetagostini
Contributor Author

cetagostini commented May 26, 2026

> > Inline the Einsum OFG immediately after building it in lower_dot, so ShapeFeature never sees it. The 2-operand Einsum produced by lower_dot has no path optimisation to defer, so inlining is safe and behaviour-preserving.
>
> The fix clearly belongs in ShapeFeature, which I linked in the other PR as already being worked on. Did you check whether simply including those changes fixes the problem?

Using your branch, the LLM takes 7s to compile instead of 2s, and training takes 84s instead of 32s. So the changes in #2056 are not sufficient to remove the full overhead. But these two lines of code do the trick @ricardoV94

@ricardoV94
Member

ricardoV94 commented May 26, 2026

> > > Inline the Einsum OFG immediately after building it in lower_dot, so ShapeFeature never sees it. The 2-operand Einsum produced by lower_dot has no path optimisation to defer, so inlining is safe and behaviour-preserving.
> >
> > The fix clearly belongs in ShapeFeature, which I linked in the other PR as already being worked on. Did you check whether simply including those changes fixes the problem?
>
> Using your branch, the LLM takes 7s to compile instead of 2s, and training takes 84s instead of 32s. So the changes in #2056 are not sufficient to remove the full overhead. But these two lines of code do the trick @ricardoV94

I'll have to look at the before/after graph, otherwise we're pushing the dirt under the carpet instead of fixing it.

If you want to, you could look at the dprint before and after, and use the optimizer_verbose config flag or the interactive rewrite to see what's actually going on.

Also, what does "compile time" mean here? Calling function? The first eval? Is it all mixed, or are some paths going through numba?
