Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 39 additions & 4 deletions barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ export class ShaderManager {
input_size: number,
curveConfig: CurveConfig = BN254_CURVE_CONFIG,
force_recompile = false,
streamed_yuval = false,
) {
this.curveConfig = curveConfig;
this.p = curveConfig.baseFieldModulus;
Expand Down Expand Up @@ -182,8 +183,11 @@ export class ShaderManager {

// Render the Karatsuba+Yuval Mont body once. This is the default
// u32 multiplier used by every MSM shader that includes the
// `montgomery_product_funcs` mustache partial.
this.mont_product_src = this.renderKaratYuvalMont();
// `montgomery_product_funcs` mustache partial. `streamed_yuval`
// interleaves Yuval reduce iters 0..9 between the lo and hi/cr
// half-products so the t[0..9] accumulator GPRs die before hi/cr
// run — see `renderKaratYuvalMont` for the algebraic argument.
this.mont_product_src = this.renderKaratYuvalMont({ streamed: streamed_yuval });

if (force_recompile) {
const rand = Math.round(Math.random() * 100000000000000000) % 2 ** 32;
Expand Down Expand Up @@ -660,7 +664,8 @@ ${packLines.join('\n')}
// sub-sub-products → inner combines → outer combine → Yuval reduce →
// final canonicalize) via mustache `{{#each}}` sections. The TS here
// just provides the index arrays + r_inv limb constants.
private renderKaratYuvalMont(): string {
private renderKaratYuvalMont(opts: { streamed?: boolean } = {}): string {
const { streamed = false } = opts;
const N = this.num_words; // 20
const WS = this.word_size; // 13
const W = 1n << BigInt(WS);
Expand Down Expand Up @@ -734,6 +739,24 @@ ${packLines.join('\n')}
{ tag: 'hi', llB: [10], hhB: [15], cB: [10, 15], folds: [{ off: 20, sign: '+' }, { off: 10, sign: '-' }] },
{ tag: 'cr', llB: [0, 10], hhB: [5, 15], cB: [0, 5, 10, 15], folds: [{ off: 10, sign: '+' }] },
];
// Builds the WGSL text for a single Yuval reduce iteration `i`, used by the
// streamed variant to inline early iterations right after the `lo` group.
// The emitted block splits t[i] into its masked part (`t_mask`) and the bits
// above WORD_SIZE (`carry`), then adds `t_mask * R_INV_j` into t[i+1+j] for
// every j in 0..N-1; the carry is added only on the first of those writes.
const renderYuvalIterText = (i: number): string => {
  // Block opener plus the two values derived from t[i].
  const prologue = [
    ' {',
    ` let t_mask: u32 = t${i} & MASK;`,
    ` let carry: u32 = t${i} >> WORD_SIZE;`,
  ];

  // One accumulate statement per R_INV limb, targeting slots i+1 .. i+N.
  const accumulates: string[] = [];
  for (let rIdx = 0; rIdx < N; rIdx += 1) {
    const dst = i + 1 + rIdx;
    let stmt = ` t${dst} = t${dst} + t_mask * R_INV_${rIdx}`;
    if (rIdx === 0) {
      stmt += ' + carry';
    }
    accumulates.push(`${stmt};`);
  }

  return [...prologue, ...accumulates, ' }'].join('\n');
};

const mb: string[] = [];
for (let s = 0; s < 2 * N; s++) mb.push(` var t${s}: u32 = 0u;`);
for (const g of kgroups) {
Expand All @@ -755,11 +778,23 @@ ${packLines.join('\n')}
mb.push(` { let p: u32 = ${pExpr(k)}; ${folds} }`);
}
mb.push(' }');
if (streamed && g.tag === 'lo') {
// After `lo` closes, t[0..9] are final (lo writes t[k]+= for k=0..18;
// hi and cr never touch t[0..9]). Yuval iter i reads only t[i], so
// iters 0..9 can run here; t[0..9] then die before hi/cr's
// 27-schoolbook + ~15-limb live set lands.
mb.push('');
mb.push(' // ===== streamed Yuval iter 0..9 (t[0..9] final after lo) =====');
for (let i = 0; i < 10; i++) mb.push(renderYuvalIterText(i));
}
}
const multiply_body = mb.join('\n');

// In streamed mode the first 10 iters are inlined above; the template
// emits iters 10..18 plus the standard reduce + final drain unchanged.
const yuvalStart = streamed ? 10 : 0;
const yuval_iters: Array<{ i: number; writes: Array<{ slot: number; r_idx: number; first: boolean }> }> = [];
for (let i = 0; i < N - 1; i++) {
for (let i = yuvalStart; i < N - 1; i++) {
const writes = [];
for (let j = 0; j < N; j++) {
writes.push({ slot: i + 1 + j, r_idx: j, first: j === 0 });
Expand Down
11 changes: 10 additions & 1 deletion barretenberg/ts/src/msm_webgpu/msm_v2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ export interface MsmConfig {
profile?: boolean;
/** Phase-2 hook — Jacobian-crossover threshold. Accepted but inert in Phase 1. */
jacobianCrossover?: number;
/**
* Render the Karatsuba+Yuval `montgomery_product` body with reduce iters
* 0..9 streamed between the `lo` and `hi`/`cr` half-products. Drops the
* t[0..9] accumulator GPRs before the high-pressure half-products run —
* a small register-pressure win on Adreno (~5 ms / 7 % median wall at
* n=2^16 on S25 Ultra). Default false; no effect on Metal/Intel where
* the kernel isn't register-bound.
*/
streamedYuval?: boolean;
/**
* Discarded warm-up `run()`s in `create()` — they ramp the GPU clock and pay
* the shader-JIT / command-buffer cold start before the first timed run.
Expand Down Expand Up @@ -594,7 +603,7 @@ export class MsmV2 {
const misc = compute_misc_params(FP, 13);
m.R = misc.r;
m.rinv = misc.rinv;
const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false);
const sm = new ShaderManager(4, n, BN254_CURVE_CONFIG, false, config?.streamedYuval ?? false);

// Bind a prefix of the shared, already-Montgomery-converted SRS pool. The
// level-0 kernels index points by `val_idx < n`, so a pool with srsN >= n
Expand Down
Loading