From 2782b15470c8a0a64bdfc784434b5e3712790b2b Mon Sep 17 00:00:00 2001 From: AztecBot Date: Fri, 22 May 2026 22:16:09 +0000 Subject: [PATCH] perf(bb/msm): streamed-Yuval karat-stream montmul knob MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a `MsmConfig.streamedYuval` switch that renders the Karatsuba+Yuval montgomery_product body with reduce iters 0..9 inlined between the `lo` and `hi`/`cr` half-products. After `lo` closes, t[0..9] are final (lo writes t[k]+= for k=0..18; hi and cr only touch t[10..38]), so iters 0..9 can run there and t[0..9] die before hi/cr's 27-schoolbook + ~15-limb live set lands — about 10 fewer accumulator GPRs alive through the high-pressure half-products. The knob is off by default (no behaviour change on existing callers). Threaded through `ShaderManager` via a new optional `streamed_yuval` ctor param; only the production MSM `ShaderManager` in `MsmV2.create` passes it, forwarding `config.streamedYuval` (default false). Same arithmetic and same template — `multiply_body` carries the inlined iters and the mustache `yuval_iters` section starts at i=10 instead of 0. Measured on BrowserStack real devices at n=2^18 × k=100 (field-mul microbench) and n=2^16 c=13 (full MSM via the dev page): S25 Ultra field-mul: 48.3 ms -> 47.1 ms (-2.5 %) S25 Ultra MSM 2^16: 65.4 ms -> 60.7 ms median (-7.2 %) S25 Ultra MSM 2^16: 58.0 ms -> 59.0 ms min (within noise) M2 / Chrome MSM 2^16: 40.4 ms streamed vs ~40 ms baseline (flat, not register-bound, no regression) Shader compile cost on Adreno is ~3 % lower with the streamed body (register allocator's job is easier when t[0..9] are dead before hi/cr). No correctness change: validated bit-exact against the host BigInt reference at 64 random scalar/point pairs on both Metal and Adreno paths. 
--- .../ts/src/msm_webgpu/cuzk/shader_manager.ts | 43 +++++++++++++++++-- barretenberg/ts/src/msm_webgpu/msm_v2.ts | 11 ++++- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 985269e67774..c95017844a7c 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -129,6 +129,7 @@ export class ShaderManager { input_size: number, curveConfig: CurveConfig = BN254_CURVE_CONFIG, force_recompile = false, + streamed_yuval = false, ) { this.curveConfig = curveConfig; this.p = curveConfig.baseFieldModulus; @@ -182,8 +183,11 @@ export class ShaderManager { // Render the Karatsuba+Yuval Mont body once. This is the default // u32 multiplier used by every MSM shader that includes the - // `montgomery_product_funcs` mustache partial. - this.mont_product_src = this.renderKaratYuvalMont(); + // `montgomery_product_funcs` mustache partial. `streamed_yuval` + // interleaves Yuval reduce iters 0..9 between the lo and hi/cr + // half-products so the t[0..9] accumulator GPRs die before hi/cr + // run — see `renderKaratYuvalMont` for the algebraic argument. + this.mont_product_src = this.renderKaratYuvalMont({ streamed: streamed_yuval }); if (force_recompile) { const rand = Math.round(Math.random() * 100000000000000000) % 2 ** 32; @@ -660,7 +664,8 @@ ${packLines.join('\n')} // sub-sub-products → inner combines → outer combine → Yuval reduce → // final canonicalize) via mustache `{{#each}}` sections. The TS here // just provides the index arrays + r_inv limb constants. 
- private renderKaratYuvalMont(): string { + private renderKaratYuvalMont(opts: { streamed?: boolean } = {}): string { + const { streamed = false } = opts; const N = this.num_words; // 20 const WS = this.word_size; // 13 const W = 1n << BigInt(WS); @@ -734,6 +739,24 @@ ${packLines.join('\n')} { tag: 'hi', llB: [10], hhB: [15], cB: [10, 15], folds: [{ off: 20, sign: '+' }, { off: 10, sign: '-' }] }, { tag: 'cr', llB: [0, 10], hhB: [5, 15], cB: [0, 5, 10, 15], folds: [{ off: 10, sign: '+' }] }, ]; + // Streamed-Yuval helper: WGSL for one Yuval reduce iter `i` (reads t[i], + // folds carry + t_mask*R_INV_j into t[i+1..i+N]). Only used in streamed + // mode, where we inline iters 0..9 between the `lo` group and `hi`/`cr` + // so that t[0..9] are dead before the high-pressure half-products run. + const renderYuvalIterText = (i: number): string => { + const lines: string[] = []; + lines.push(' {'); + lines.push(` let t_mask: u32 = t${i} & MASK;`); + lines.push(` let carry: u32 = t${i} >> WORD_SIZE;`); + for (let j = 0; j < N; j++) { + const slot = i + 1 + j; + const carryAdd = j === 0 ? ' + carry' : ''; + lines.push(` t${slot} = t${slot} + t_mask * R_INV_${j}${carryAdd};`); + } + lines.push(' }'); + return lines.join('\n'); + }; + const mb: string[] = []; for (let s = 0; s < 2 * N; s++) mb.push(` var t${s}: u32 = 0u;`); for (const g of kgroups) { @@ -755,11 +778,23 @@ ${packLines.join('\n')} mb.push(` { let p: u32 = ${pExpr(k)}; ${folds} }`); } mb.push(' }'); + if (streamed && g.tag === 'lo') { + // After `lo` closes, t[0..9] are final (lo writes t[k]+= for k=0..18; + // hi and cr never touch t[0..9]). Yuval iter i reads only t[i], so + // iters 0..9 can run here; t[0..9] then die before hi/cr's + // 27-schoolbook + ~15-limb live set lands. 
+ mb.push(''); + mb.push(' // ===== streamed Yuval iter 0..9 (t[0..9] final after lo) ====='); + for (let i = 0; i < 10; i++) mb.push(renderYuvalIterText(i)); + } } const multiply_body = mb.join('\n'); + // In streamed mode the first 10 iters are inlined above; the template + // emits iters 10..18 plus the standard reduce + final drain unchanged. + const yuvalStart = streamed ? 10 : 0; const yuval_iters: Array<{ i: number; writes: Array<{ slot: number; r_idx: number; first: boolean }> }> = []; - for (let i = 0; i < N - 1; i++) { + for (let i = yuvalStart; i < N - 1; i++) { const writes = []; for (let j = 0; j < N; j++) { writes.push({ slot: i + 1 + j, r_idx: j, first: j === 0 }); diff --git a/barretenberg/ts/src/msm_webgpu/msm_v2.ts b/barretenberg/ts/src/msm_webgpu/msm_v2.ts index 5273ef8c2bad..c9263d53c5f7 100644 --- a/barretenberg/ts/src/msm_webgpu/msm_v2.ts +++ b/barretenberg/ts/src/msm_webgpu/msm_v2.ts @@ -64,6 +64,15 @@ export interface MsmConfig { profile?: boolean; /** Phase-2 hook — Jacobian-crossover threshold. Accepted but inert in Phase 1. */ jacobianCrossover?: number; + /** + * Render the Karatsuba+Yuval `montgomery_product` body with reduce iters + * 0..9 streamed between the `lo` and `hi`/`cr` half-products. Drops the + * t[0..9] accumulator GPRs before the high-pressure half-products run — + * a small register-pressure win on Adreno (~5 ms / 7 % median wall at + * n=2^16 on S25 Ultra). Default false; no effect on Metal/Intel where + * the kernel isn't register-bound. + */ + streamedYuval?: boolean; /** * Discarded warm-up `run()`s in `create()` — they ramp the GPU clock and pay * the shader-JIT / command-buffer cold start before the first timed run. @@ -594,7 +603,7 @@ export class MsmV2 { const misc = compute_misc_params(FP, 13); m.R = misc.r; m.rinv = misc.rinv; - const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false); + const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false, config?.streamedYuval ?? 
false); // Bind a prefix of the shared, already-Montgomery-converted SRS pool. The // level-0 kernels index points by `val_idx < n`, so a pool with srsN >= n