diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index 985269e67774..c95017844a7c 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -129,6 +129,7 @@ export class ShaderManager { input_size: number, curveConfig: CurveConfig = BN254_CURVE_CONFIG, force_recompile = false, + streamed_yuval = false, ) { this.curveConfig = curveConfig; this.p = curveConfig.baseFieldModulus; @@ -182,8 +183,11 @@ export class ShaderManager { // Render the Karatsuba+Yuval Mont body once. This is the default // u32 multiplier used by every MSM shader that includes the - // `montgomery_product_funcs` mustache partial. - this.mont_product_src = this.renderKaratYuvalMont(); + // `montgomery_product_funcs` mustache partial. `streamed_yuval` + // interleaves Yuval reduce iters 0..9 between the lo and hi/cr + // half-products so the t[0..9] accumulator GPRs die before hi/cr + // run — see `renderKaratYuvalMont` for the algebraic argument. + this.mont_product_src = this.renderKaratYuvalMont({ streamed: streamed_yuval }); if (force_recompile) { const rand = Math.round(Math.random() * 100000000000000000) % 2 ** 32; @@ -660,7 +664,8 @@ ${packLines.join('\n')} // sub-sub-products → inner combines → outer combine → Yuval reduce → // final canonicalize) via mustache `{{#each}}` sections. The TS here // just provides the index arrays + r_inv limb constants. - private renderKaratYuvalMont(): string { + private renderKaratYuvalMont(opts: { streamed?: boolean } = {}): string { + const { streamed = false } = opts; const N = this.num_words; // 20 const WS = this.word_size; // 13 const W = 1n << BigInt(WS); @@ -734,6 +739,24 @@ ${packLines.join('\n')} { tag: 'hi', llB: [10], hhB: [15], cB: [10, 15], folds: [{ off: 20, sign: '+' }, { off: 10, sign: '-' }] }, { tag: 'cr', llB: [0, 10], hhB: [5, 15], cB: [0, 5, 10, 15], folds: [{ off: 10, sign: '+' }] }, ]; + // Streamed-Yuval helper: WGSL for one Yuval reduce iter `i` (reads t[i], + // folds carry + t_mask*R_INV_j into t[i+1..i+N]). Only used in streamed + // mode, where we inline iters 0..9 between the `lo` group and `hi`/`cr` + // so that t[0..9] are dead before the high-pressure half-products run. + const renderYuvalIterText = (i: number): string => { + const lines: string[] = []; + lines.push(' {'); + lines.push(` let t_mask: u32 = t${i} & MASK;`); + lines.push(` let carry: u32 = t${i} >> WORD_SIZE;`); + for (let j = 0; j < N; j++) { + const slot = i + 1 + j; + const carryAdd = j === 0 ? ' + carry' : ''; + lines.push(` t${slot} = t${slot} + t_mask * R_INV_${j}${carryAdd};`); + } + lines.push(' }'); + return lines.join('\n'); + }; + const mb: string[] = []; for (let s = 0; s < 2 * N; s++) mb.push(` var t${s}: u32 = 0u;`); for (const g of kgroups) { @@ -755,11 +778,23 @@ ${packLines.join('\n')} mb.push(` { let p: u32 = ${pExpr(k)}; ${folds} }`); } mb.push(' }'); + if (streamed && g.tag === 'lo') { + // After `lo` closes, t[0..9] are final (lo writes t[k]+= for k=0..18; + // hi and cr never touch t[0..9]). Yuval iter i reads only t[i], so + // iters 0..9 can run here; t[0..9] then die before hi/cr's + // 27-schoolbook + ~15-limb live set lands. + mb.push(''); + mb.push(' // ===== streamed Yuval iter 0..9 (t[0..9] final after lo) ====='); + for (let i = 0; i < 10; i++) mb.push(renderYuvalIterText(i)); + } } const multiply_body = mb.join('\n'); + // In streamed mode the first 10 iters are inlined above; the template + // emits iters 10..18 plus the standard reduce + final drain unchanged. + const yuvalStart = streamed ? 10 : 0; const yuval_iters: Array<{ i: number; writes: Array<{ slot: number; r_idx: number; first: boolean }> }> = []; - for (let i = 0; i < N - 1; i++) { + for (let i = yuvalStart; i < N - 1; i++) { const writes = []; for (let j = 0; j < N; j++) { writes.push({ slot: i + 1 + j, r_idx: j, first: j === 0 }); diff --git a/barretenberg/ts/src/msm_webgpu/msm_v2.ts b/barretenberg/ts/src/msm_webgpu/msm_v2.ts index 5273ef8c2bad..c9263d53c5f7 100644 --- a/barretenberg/ts/src/msm_webgpu/msm_v2.ts +++ b/barretenberg/ts/src/msm_webgpu/msm_v2.ts @@ -64,6 +64,15 @@ export interface MsmConfig { profile?: boolean; /** Phase-2 hook — Jacobian-crossover threshold. Accepted but inert in Phase 1. */ jacobianCrossover?: number; + /** + * Render the Karatsuba+Yuval `montgomery_product` body with reduce iters + * 0..9 streamed between the `lo` and `hi`/`cr` half-products. Drops the + * t[0..9] accumulator GPRs before the high-pressure half-products run — + * a small register-pressure win on Adreno (~5 ms / 7 % median wall at + * n=2^16 on S25 Ultra). Default false; no effect on Metal/Intel where + * the kernel isn't register-bound. + */ + streamedYuval?: boolean; /** * Discarded warm-up `run()`s in `create()` — they ramp the GPU clock and pay * the shader-JIT / command-buffer cold start before the first timed run. @@ -594,7 +603,7 @@ export class MsmV2 { const misc = compute_misc_params(FP, 13); m.R = misc.r; m.rinv = misc.rinv; - const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false); + const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false, config?.streamedYuval ?? false); // Bind a prefix of the shared, already-Montgomery-converted SRS pool. The // level-0 kernels index points by `val_idx < n`, so a pool with srsN >= n