perf(bb/msm): Jacobian (S, W) tree bucket reduction #23526

Draft
AztecBot wants to merge 13 commits into zw/msm-webgpu-experiments-v2 from cb/cc66f83b5de3

@AztecBot
Collaborator

Motivation

The previous bucket-reduction stage runs one workgroup per window with
batch-affine Montgomery-trick adds. At small c (e.g. c=8) each
window has only 2^(c-1)=128 buckets to drain and the workgroup is
mostly idle, while batch-affine forces each thread to gather S=8
points before any work happens.

This PR replaces that stage with a fully parallel Jacobian tree
reduction that treats all NW * 2^(c-1) weighted buckets as one block
of work and dispatches every merge in parallel. Each thread does a
small constant-time merge with no batched inversion and no
S-point gather, so GPU saturation no longer depends on per-window
fan-out.

Algorithm

For each window w, each tree node summarises a contiguous range of
buckets as a pair (S, W):

  • S = Σ B[k] over the range,
  • W = Σ (pos · B[k]) with pos = 1..h relative to the range start.

Merging two adjacent (S_L, W_L) and (S_R, W_R) of size h each into
one (S, W) of size 2h:

S    = S_L + S_R                       (1 group add)
hS_R = double S_R, log2(h) times       (log2(h) Jacobian doublings)
W    = W_L + hS_R + W_R                (2 group adds)

At the root, W is L_w = Σ k · B_w[k] for k = 1..N. Per window the
total is ≈ 5N/2 adds + ≈ N doublings — fewer rounds than the old
4-phase reduction and with full inter-window parallelism throughout.
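
The merge rule above can be checked outside the GPU with a few lines of host code. The following is a minimal sketch, not the PR's `jbr_reference.test.mjs`: it assumes a toy group (integers mod a prime under addition) in place of curve points, and the names `SWNode`, `mergeSW`, `treeReduce` and `naive` are hypothetical. The identity W = W_L + h·S_R + W_R holds in any abelian group, so the toy group is enough to exercise the tree shape and the doubling count.

```typescript
// Toy stand-in for the curve group: integers mod P under addition.
// (Assumption for illustration only; the PR operates on elliptic-curve points.)
const P = 1000003;
const gAdd = (a: number, b: number): number => (a + b) % P;
const gDouble = (a: number): number => (a + a) % P;

interface SWNode { S: number; W: number }

// Merge two adjacent nodes that each cover h buckets (h is a power of two).
function mergeSW(L: SWNode, R: SWNode, h: number): SWNode {
  const S = gAdd(L.S, R.S);                                   // 1 group add
  let hSR = R.S;
  for (let step = 1; step < h; step *= 2) hSR = gDouble(hSR); // log2(h) doublings
  const W = gAdd(gAdd(L.W, hSR), R.W);                        // 2 group adds
  return { S, W };
}

// Reduce buckets B[1..N] (stored 0-indexed, N a power of two) to sum of k * B[k].
function treeReduce(buckets: number[]): number {
  // A leaf covers one bucket at pos = 1, so S = W = B[k].
  let level: SWNode[] = buckets.map((b) => ({ S: b, W: b }));
  for (let h = 1; level.length > 1; h *= 2) {
    const next: SWNode[] = [];
    for (let i = 0; i < level.length; i += 2) {
      next.push(mergeSW(level[i], level[i + 1], h));
    }
    level = next;
  }
  return level[0].W;
}

// Direct evaluation of sum of k * B[k] for comparison.
function naive(buckets: number[]): number {
  let acc = 0;
  buckets.forEach((b, i) => { acc = (acc + ((i + 1) * b) % P) % P; });
  return acc;
}
```

For two buckets `[5, 11]` the single merge gives W = 5 + 11 + 11 = 27, which equals 1·5 + 2·11.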

Round 0 (AA -> J, leaf) and rounds 1..c-2 (JJ -> J) are
implemented as separate shaders, because the first level can use the
cheaper mmadd/madd formulas (Z = 1 both inputs) and skip Jacobian
loads entirely. As required by the design, neither shader checks the
x1 == x2 case — SRS bases are randomly-independent generators and
coordinate collisions in the reduction are vanishingly improbable.

Layout

  • (S, W) pairs live in a 6-plane SoA buffer (S.X, S.Y, S.Z, W.X, W.Y, W.Z).
  • Two ping-pong buffers cover all rounds; the final round writes to a
    dedicated jbrFinalBuf whose plane stride equals NUM_WINDOWS, so
    the per-window L_w gather is a contiguous 3-plane copy into
    redStaging.
  • The host receives NW Jacobian points and Horner-combines them in
    Jacobian (one inversion at the very end). The C++ bridge consumer
    still sees per-window affine sums via windowSums (one inversion
    per window) for combineOnHost = false.

Files

  • wgsl/cuzk/jbr_aa_to_jj.template.wgsl — round 0 (AA -> J leaf merge)
  • wgsl/cuzk/jbr_jj_to_jj.template.wgsl — rounds 1..c-2 (JJ -> J merge)
  • cuzk/shader_manager.ts — two new gen_jbr_* shader builders
  • msm_v2.ts — replaces reduceInit + reduceLevel* dispatch with
    the AA->J + (c-2)×JJ->J schedule, rewires the L_w gather to the
    Jacobian buffer, and converts hostWindowCombine to take Jacobian
    inputs (the existing Horner already ran in Jacobian internally).
  • jbr_reference.test.mjs — standalone bigint reference that verifies
    the (S, W) merge formula matches Σ k · B[k] for c = 2..10.
  • dev/msm-webgpu/main.ts + scripts/run-browserstack.mjs — adds the
    ?autorun=msm-gpu-noble mode (and an --autorun flag on the
    driver) so a WebGPU-only correctness check against noble can run on
    BrowserStack without paying for the bb.js WASM bootstrap.

Test plan

  • node src/msm_webgpu/jbr_reference.test.mjs passes for c=2..10
    (bigint reference matches a naive Σ k · B[k] MSM).
  • Browser correctness via ?autorun=msm-gpu-noble&logn=14 on a
    BrowserStack macOS Sequoia / Chrome target (M2 Mini).
  • Bench ?autorun=msm-cross-check once a slot frees up to measure
    c=8 and c=13 against the existing batch-affine reduction.

Created by claudebox · group: slackbot

AztecBot added the claudebox label (Owned by claudebox. it can push to this PR.) May 23, 2026
AztecBot added 12 commits May 23, 2026 10:10
bucket_result stores (0,0) for empty buckets; the JBR formulas were
feeding those non-points into mmadd/madd/add-2007-bl and propagating
garbage through the tree. Adds per-node `meta = is_present | unitp<<1`:
empty leaves fall through case (0,0); case (1,0)/(0,1) at round 0 lift
the single bucket to Jacobian with unitp = 1 or 2 respectively.

Round r >= 1 case (0,1) needed extra care — when the right child is a
single-bucket subtree with R.unitp == h (= 2^r), `h*R.S` and `R.W` are
both the Jacobian form of the SAME 2^r·B[k] and the standard jacAdd
hits the doubling case. Detect via meta and shortcut to jacDouble(R.W);
the standard formula remains safe everywhere else (multi-bucket R mixes
distinct generators; unit R with R.unitp != h gives distinct group
elements).
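
A rough host-side illustration of the flag word and of why `unitp == h` forces the doubling case. This is a sketch under assumptions, not the shader code: it reuses a toy additive group mod a prime in place of curve points, it assumes `unitp = 0` means "not a single-bucket subtree", and the helper names (`packMeta`, `unpackMeta`, `timesH`, `hitsDoublingCase`) are hypothetical.

```typescript
const P = 1000003; // toy modulus standing in for the curve group (assumption)

// Pack and unpack meta = is_present | unitp << 1.
function packMeta(isPresent: boolean, unitp: number): number {
  return ((isPresent ? 1 : 0) | (unitp << 1)) >>> 0;
}
function unpackMeta(meta: number): { isPresent: boolean; unitp: number } {
  return { isPresent: (meta & 1) === 1, unitp: meta >>> 1 };
}

// h * S via log2(h) doublings, as in the merge step.
function timesH(S: number, h: number): number {
  let x = S;
  for (let step = 1; step < h; step *= 2) x = (x + x) % P;
  return x;
}

// True when the right child holds a single bucket at position h, so that
// h * R.S and R.W are the same group element and a plain add would fail.
function hitsDoublingCase(rightMeta: number, h: number): boolean {
  const { isPresent, unitp } = unpackMeta(rightMeta);
  return isPresent && unitp === h;
}
```

With h = 4 and a right child whose only occupied bucket B sits at position 4, R.S = B and R.W = 4·B, so `timesH(R.S, 4)` equals R.W. If the bucket sits at position 3 instead, R.W = 3·B and the two operands differ.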

Host-side hostWindowCombine now uses jacAddSafe / jacDoubleSafe + a
JAC_INF sentinel so all-empty windows propagate as point-at-infinity
through the Horner.
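
One way to picture the sentinel-aware Horner combine is shown here. It is a sketch only: it assumes each window is shifted by c bits, uses a toy additive group mod a prime, and models the point-at-infinity sentinel as `null`. The names `addSafe`, `doubleSafe` and `hornerCombine` are hypothetical stand-ins, not the PR's `jacAddSafe` / `jacDoubleSafe` / `hostWindowCombine`.

```typescript
const P = 1000003;       // toy modulus (assumption)
type Pt = number | null; // null plays the role of the infinity sentinel

function addSafe(a: Pt, b: Pt): Pt {
  if (a === null) return b; // infinity + b = b
  if (b === null) return a; // a + infinity = a
  return (a + b) % P;
}

function doubleSafe(a: Pt): Pt {
  return a === null ? null : (a + a) % P; // 2 * infinity = infinity
}

// Combine per-window sums L_w as sum of 2^(c*w) * L_w, top window first.
function hornerCombine(windowSums: Pt[], c: number): Pt {
  let acc: Pt = null;
  for (let w = windowSums.length - 1; w >= 0; w--) {
    for (let i = 0; i < c; i++) acc = doubleSafe(acc); // shift by one window
    acc = addSafe(acc, windowSums[w]);
  }
  return acc;
}
```

With `windowSums = [3, null, 5]` and `c = 4` the empty middle window contributes nothing and the result is 3 + 5·256 = 1283; an all-`null` input stays `null`.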

`jbr_reference.test.mjs` extended with empty-bucket cases
(top-empty, first-empty, sparse-top, sparse-mid, alternating,
all-empty). sparse-mid exercises the unitp==h collision chain that
the BrowserStack run originally exposed.
#2 from the audit: manually inline jac_add / jac_double in the AA->J and
JJ->J shaders, break the case-(1,1) hot path into scoped stages, and
defer loads to first use. WGSL doesn't guarantee that function calls
with array<u32,8> by-value parameters inline cleanly; the previous
shader passed 6 such arrays per jac_add call and three jac_add calls
per merge, easily exceeding the per-thread vector-register budget on
Adreno (and squeezing it on M2).

The JJ hot path now has four scoped stages — S, doublings, W_tmp, W_new
— with sl*/sr*/wl*/wr* loaded inside the stage that consumes them.
Stage outputs that bridge stages (dx/dy/dz, wtx/wty/wtz) are outer-scope
vars rewritten in place; stage-internal intermediates fall out of the
live set at the closing brace.

#1 from the audit: pickReduceWg is now a flat 128 regardless of c. The
old c-tiered table (32/64/128) was tuned for the batch-affine kernel
where workgroup size capped at 2^(c-1); the new flat-tree dispatch
doesn't have that constraint, and 128-thread WGs occupy a core fully
without leaving simdgroups idle on late, sparse rounds.

Reference test (jbr_reference.test.mjs) unchanged and still passes.
GPU correctness was last verified on Apple M2; an S25 (Adreno 750)
bench follows.

Also brings the autorun=msm-gpu-bench page mode in dev/msm-webgpu so
the next sweep can be driven from a single BrowserStack URL.
…iler limit

The previous commit inlined the Jacobian formulas straight into the AA and
JJ merge kernels. On the Apple M2 this compiles fine; on a Samsung
Galaxy S25 (Adreno 750 / Snapdragon 8 Elite) the Vulkan driver returns
VK_ERROR_UNKNOWN from CreateComputePipelines — the post-mustache shader
exceeds Adreno's per-kernel size/code-cache budget. Reverting both
shaders to the function-call form; pickReduceWg flat 128 stays.

S25 with the functional form runs the bench end-to-end at logN=14, c=8:
wall_min 21.4 ms / wall_mean 29.3 ms (thermal throttling visible across
the 20-sample run). redLevel = 2.08 ms (min), redInit = 0.07 ms (min).

Will attack register pressure via the field-element representation
instead — vec4<u32>×2 in place of array<u32,8> doesn't grow the shader
text, and it gives the compiler 2 vector-register slots per field
instead of 8 scalar slots.
…loads

Two pieces aimed at reducing per-thread register pressure and the
per-round kernel-dispatch overhead:

1. jbr_window_coop kernel — for c ≤ 8 the per-window (S, W) tree fits
   N/2 ≤ 64 nodes × 196 B per node ≈ 12.25 KiB inside a workgroup's
   threadgroup memory (under the 16 KiB WebGPU spec minimum). One
   workgroup owns one window: thread tid does the AA→J pair-merge,
   workgroupBarrier, then each subsequent JJ→J sub-round halves the
   active thread count with another barrier in between. Replaces the
   ~7 individual dispatches the c=8 redLevel previously needed with a
   single dispatch (workgroup_size = N/2, dispatch_count = NW).
   MsmV2.create gates on c ≤ 8; larger c stays on the multi-dispatch
   JJ path, which would breach the 16 KiB TG budget.

2. JJ shader — defer wl* / wr* loads to the scoped block that
   consumes them. With every load up-front the wl* and wr* fields
   stayed live throughout the S-stage jac_add, inflating the
   per-thread peak live-set. Loading inside `{}` lets the compiler
   drop the temporaries at scope close. (Inlining failed Adreno's
   shader compiler on the last attempt, so this stays in function-call form.)
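
The c ≤ 8 gate in item 1 follows from simple arithmetic, sketched here. The helper names and constants are hypothetical; 196 B per node is the figure quoted in item 1, and 16384 bytes is the default WebGPU `maxComputeWorkgroupStorageSize` limit.

```typescript
// Default WebGPU limit for workgroup (threadgroup) storage, in bytes.
const WEBGPU_MIN_WORKGROUP_STORAGE = 16384;
// Per-node footprint quoted in the commit message (assumed fixed here).
const NODE_BYTES = 196;

// Bytes of workgroup memory one window's tree needs after the leaf merge.
function coopFootprintBytes(c: number): number {
  const buckets = 1 << (c - 1); // N = 2^(c-1) buckets per window
  const nodes = buckets / 2;    // N/2 nodes after the AA->J pair-merge
  return nodes * NODE_BYTES;
}

// Whether the single-dispatch kernel fits the portable storage budget.
function fitsCoopKernel(c: number): boolean {
  return coopFootprintBytes(c) <= WEBGPU_MIN_WORKGROUP_STORAGE;
}
```

For c = 8 this gives 64 × 196 = 12,544 bytes, which fits; for c = 9 it gives 25,088 bytes, which does not.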

Both changes preserve the existing case-split for empty buckets and
single-bucket subtree positions; the jbr_reference.test.mjs cases
(sparse-top/mid, alternating, all-empty) still pass.

Apple M2 (chrome 148) noble check at logN=14 still matches.
Samsung S25 (Adreno 750) bench pending.
S25 spotted: 'active' is reserved in current WGSL. Trivial rename of the
loop counter in jbr_window_coop.
S25 bench (logN=14, c=8) ran the WG-coop kernel and hit wall_min 293 ms
(redInit alone 257 ms) vs 21.5 ms baseline. The combined single-dispatch
shader with 6 sub-rounds + 6 case-splits compiles cleanly but Adreno's
register allocator clearly spills almost everything to global memory.

Keep the shader + plumbing; just gate the host pipeline compilation off
(). The multi-dispatch JJ path remains the active code
path for every c. The deferred wl* / wr* loads from the previous commit
benefit that path directly and weren't being exercised because c=8 was
taking the WG-coop branch.
Goal: cut per-thread register pressure across the heavy S-stage jac_add
on Adreno. Source-level deferred loads didn't move S25's redLevel
because the compiler hoists the wl*/wr* loads back to function entry —
it has nothing else to do during the initial S-stage scheduling and
treats the loads as free latency hiding. WGSL's memory model lets us
force the issue: a workgroup-memory write of wl/wr at entry plus a
workgroup-memory read at the W stage is non-reorderable, so the
compiler MUST drop the values from registers between the write and the
read.

Layout: `tg_wlwr: array<vec4<u32>, 12 * WG>` — 6 field-elements ×
2 vec4 per thread = 192 B per thread. WG=128 → 24 KiB per workgroup,
inside the 32 KiB Adreno typically exposes (gpu.ts already requests
the adapter max). The Apple M2 fits this easily.

Six fields = wl{x,y,z} + wr{x,y,z}. They were the only loads
overlapping the S-stage jac_add; sl/sr stay register-resident because
they're consumed inside the same scope as the S compute.

Algorithm is unchanged — jbr_reference.test.mjs still passes every
case (sparse-top, sparse-mid, first-empty, top-empty, alternating,
all-empty).
First WG-mem attempt put wl+wr through TG (24 KiB at WG=128) — Adreno
returns VK_ERROR_UNKNOWN on CreateComputePipelines, almost certainly
because 24 KiB is over the device's real TG cap even though gpu.ts
requests the adapter max. wr* is the long-lived one (survives until
the final jac_add at stage D); wl* is consumed at stage C, soon after
the doublings, so a source-level deferred load is enough for it.
Footprint drops to 12 KiB (96 B × WG=128), inside the 16 KiB WebGPU
minimum every device has.
The (S, W) tree has hit Adreno's per-merge compute floor and source-level
register-pressure tricks (vec4 packing, deferred loads, wr-only TG
offload) don't move the 2.097 ms redLevel needle on S25.

This adds a different algorithm — bit decomposition — gated by the new
`reduction: 'jbr-sw' | 'bdr'` MsmConfig field. For each window w and
bit j in [0, c-2], compute G[w,j] = Σ B[k] over k with bit_j(k) = 1
via NW × (c-1) independent pure-sum trees; then combine into
L_w = Σ 2^j · G[w,j] with a per-window Horner kernel. j = c-1 is
trivially lifted from B[N] inside the Horner.
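
The bit-decomposition identity can be sanity-checked on the host. The sketch below is illustrative only: it uses a toy additive group mod a prime instead of curve points, computes each G[j] with a plain loop rather than the pure-sum trees the PR dispatches, and the names `bitSums` and `bdrCombine` are hypothetical.

```typescript
const P = 1000003; // toy modulus standing in for the curve group (assumption)

// G[j] = sum of B[k] over k with bit j of k set, for j = 0..c-2.
// buckets holds B[1..N] 0-indexed, with N = 2^(c-1).
function bitSums(buckets: number[], c: number): number[] {
  const G: number[] = new Array(c - 1).fill(0);
  for (let k = 1; k <= buckets.length; k++) {
    for (let j = 0; j < c - 1; j++) {
      if ((k >> j) & 1) G[j] = (G[j] + buckets[k - 1]) % P;
    }
  }
  return G;
}

// L = 2^(c-1) * B[N] + sum over j of 2^j * G[j], by Horner from the top bit.
function bdrCombine(G: number[], topBucket: number): number {
  let acc = topBucket; // the j = c-1 term is just B[N]
  for (let j = G.length - 1; j >= 0; j--) {
    acc = (((acc + acc) % P) + G[j]) % P; // double, then add the next bit sum
  }
  return acc;
}
```

For c = 4 and buckets `[3, 1, 4, 1, 5, 9, 2, 6]`, the bit sums are G = [14, 16, 17] with B[8] = 6, and 14 + 2·16 + 4·17 + 8·6 = 162, the same as Σ k · B[k].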

Per-merge state shrinks substantially:
- JBR JJ case (1,1): 3 jac_add + r doublings = 48 + 8r mults,
  ~25 fields peak (12 input + jac_add internals + carry).
- BDR JJ: 1 jac_add = 16 mults, ~16 fields peak.

Per-round thread count grows 7× for c=8: NW × (c-1) × tree-stage-size
vs NW × tree-stage-size. The smallest BDR round still has NW × (c-1) =
224 threads vs JBR's NW = 32, so Adreno's late-round occupancy
improves and the per-merge cost should drop accordingly.

Default stays 'jbr-sw' so existing call sites are unchanged. Pass
`?reduction=bdr` to the dev page to opt in. The BDR Horner kernel
writes its output into the same jbrFinalBuf W planes + jbrPresenceFinal
that JBR uses, so run()'s gather + host Horner combine are identical
for both paths.

S25 benchmark vs the current JBR baseline to follow.
…count

Noble check at logN=14 returned mismatching gpu vs noble for BDR. The
Horner shader's g_plane_stride was set to NUM_WINDOWS * (c-1) (224 for
c=8) — the live G value count — but the buffer that the last JJ round
wrote into has stride planeA or planeB (whichever ping-pong slot it
landed in: planeB = 3584 for c=8). The Horner then reads Y / Z planes
at base * 224 instead of base * 3584, a stride 16× too small.

Pass the actual prevStride into Horner instead.
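
The effect of a mismatched plane stride is easy to reproduce in isolation. This sketch assumes a plane-major SoA layout where element `base` of plane `plane` sits at `plane * planeStride + base`; the function names are hypothetical and the buffer contents are tagged values, not field elements.

```typescript
// Index of element `base` in coordinate plane `plane` (0 = X, 1 = Y, 2 = Z).
function soaIndex(plane: number, base: number, planeStride: number): number {
  return plane * planeStride + base;
}

// Fill a 3-plane buffer so each slot records which plane and element it is.
function makeBuffer(stride: number): number[] {
  const buf: number[] = new Array(3 * stride).fill(0);
  for (let plane = 0; plane < 3; plane++) {
    for (let base = 0; base < stride; base++) {
      buf[soaIndex(plane, base, stride)] = plane * 100000 + base;
    }
  }
  return buf;
}

// Gather (X, Y, Z) of one element using the stride the reader believes in.
function readPoint(buf: number[], base: number, planeStride: number): [number, number, number] {
  return [
    buf[soaIndex(0, base, planeStride)],
    buf[soaIndex(1, base, planeStride)],
    buf[soaIndex(2, base, planeStride)],
  ];
}
```

Writing with stride 3584 and reading element 5 with stride 224 returns X-plane entries 229 and 453 in place of the Y and Z coordinates, which is the kind of mismatch the noble check caught.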
Acknowledging the audit conclusion: c=8 is leaving 4.5× per-mult
throughput on the table vs c=13 because the smallest JJ round dispatches
only ~32 threads. Source-level register-pressure tricks didn't move
the floor; algorithm-level parallelism does.

Cooperative jac_add: 4 threads per merge sharing intermediates via
workgroup memory across the 6 dependency levels of add-2007-bl.
- L0  (4 mp): z1z1, z2z2, y1z2, y2z1 — one per lane.
- L1  (4 mp): u1, u2, s1, s2.
- L2a (3 mp): i = twoh², r2 = r², zsum2 = zsum² (lane 3 idle).
- L2b (2 mp): j = h*i, v = u1*i (lanes 2-3 idle).
- L2c (3 mp): rvx3 = r*vx3, s1j = s1*j, z3 = zdelta*h.
Five workgroupBarriers between levels.

All active lanes in each level execute the SAME operation kind (every
level is "every active lane does one mp on different operands"), so
SIMD lockstep doesn't penalise the cooperation. Per-thread peak live
state shrinks from BDR JJ's ~16 fields to ~6; dispatched thread count
per JJ round grows 4×, lifting c=8's smallest-round dispatch from 32
to 128 — closer to c=13's 640-thread saturation regime.

WG hard-coded to 64 (4 lanes × 16 merges per WG). TG memory: 16 slots
per merge × 16 merges × 32 B = 8 KiB per WG — well inside any
device's workgroup-storage budget.

Plumbed end-to-end as a new MsmConfig variant 'bdr-coop'; the existing
'bdr' (non-coop BDR) and 'jbr-sw' (default) paths are unchanged.
Out-of-bounds / non-(1,1) lanes still participate in every barrier so
the WG-wide synchronisation stays well-defined.

S25 noble check + bench to follow once the BS slot opens.
First coop pass had `switch (lane) { case 0u: { v = mp(...); } ... }`
which compiles to mp-per-case. On SIMD architectures (Adreno, M2) each
case executes with the other lanes masked, so the four mp calls
serialise — no parallelism gained. Per-merge cost stayed identical to
non-coop BDR.

Restructure each level so the switch only chooses the (a, b) operands
(cheap register moves) and a single uniform `montgomery_product_f8(a, b)`
call site sits outside the switch. All active lanes hit the same mp
call simultaneously with their own (a, b); SIMD lockstep runs them in
parallel. Idle lanes pass (0, 0) so the discarded result still keeps
the call site uniform.
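
The restructured level layout can be mimicked on the host to check that the operand tables still compute a correct add. The sketch below is an assumption-laden model, not the shader: it uses a small prime field with plain `number` arithmetic, a `mul` helper standing in for `montgomery_product_f8`, a sequential loop over four lanes in place of SIMD execution, and ternaries where the shader uses a `switch`. All function names are hypothetical.

```typescript
const P = 10007; // small prime; toy field (assumption)
const mod = (x: number): number => ((x % P) + P) % P;
const mul = (a: number, b: number): number => mod(a * b); // stand-in for the Montgomery product

type Jac = { X: number; Y: number; Z: number };

// One cooperative level: each lane only selects operands; the multiply
// happens at a single shared call site. Idle lanes pass (0, 0).
function level(pick: (lane: number) => [number, number]): number[] {
  const out: number[] = [];
  for (let lane = 0; lane < 4; lane++) {
    const [a, b] = pick(lane);
    out.push(mul(a, b));
  }
  return out;
}

// add-2007-bl split into the five multiply levels listed in the previous commit.
function coopJacAdd(p: Jac, q: Jac): Jac {
  const [z1z1, z2z2, y1z2, y2z1] = level((l) =>
    l === 0 ? [p.Z, p.Z] : l === 1 ? [q.Z, q.Z] : l === 2 ? [p.Y, q.Z] : [q.Y, p.Z]);
  const [u1, u2, s1, s2] = level((l) =>
    l === 0 ? [p.X, z2z2] : l === 1 ? [q.X, z1z1] : l === 2 ? [y1z2, z2z2] : [y2z1, z1z1]);
  const h = mod(u2 - u1);
  const twoh = mod(2 * h);
  const r = mod(2 * (s2 - s1));
  const zsum = mod(p.Z + q.Z);
  const [i, r2, zsum2] = level((l) =>
    l === 0 ? [twoh, twoh] : l === 1 ? [r, r] : l === 2 ? [zsum, zsum] : [0, 0]);
  const [j, v] = level((l) => (l === 0 ? [h, i] : l === 1 ? [u1, i] : [0, 0]));
  const x3 = mod(r2 - j - 2 * v);
  const vx3 = mod(v - x3);
  const zdelta = mod(zsum2 - z1z1 - z2z2);
  const [rvx3, s1j, z3] = level((l) =>
    l === 0 ? [r, vx3] : l === 1 ? [s1, j] : l === 2 ? [zdelta, h] : [0, 0]);
  return { X: x3, Y: mod(rvx3 - 2 * s1j), Z: z3 };
}

// Helpers used only to compare against the affine chord formula.
function powMod(b: number, e: number): number {
  let result = 1;
  let base = mod(b);
  while (e > 0) {
    if (e & 1) result = mul(result, base);
    base = mul(base, base);
    e >>= 1;
  }
  return result;
}
const inv = (a: number): number => powMod(a, P - 2);
function lift(x: number, y: number, z: number): Jac {
  const z2 = mul(z, z);
  return { X: mul(x, z2), Y: mul(y, mul(z2, z)), Z: z };
}
function toAffine(p: Jac): [number, number] {
  const zi = inv(p.Z);
  const zi2 = mul(zi, zi);
  return [mul(p.X, zi2), mul(p.Y, mul(zi2, zi))];
}
```

The chord-addition formulas do not involve the curve coefficients, so any two affine pairs with distinct x coordinates can serve as a check: lifting them with arbitrary Z values, running `coopJacAdd`, and converting back should match λ = (y2 − y1)/(x2 − x1), x3 = λ² − x1 − x2, y3 = λ(x1 − x3) − y1.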

If this doesn't help either, the next move is co-Z-friendly batch
arithmetic (which does need a batched inversion, as you pointed out)
or a smaller Mont-mult body.