diff --git a/barretenberg/ts/dev/msm-webgpu/main.ts b/barretenberg/ts/dev/msm-webgpu/main.ts index 08c24fd60492..fc6909b63b90 100644 --- a/barretenberg/ts/dev/msm-webgpu/main.ts +++ b/barretenberg/ts/dev/msm-webgpu/main.ts @@ -29,6 +29,24 @@ import { createWasmPippenger, parseAffineLE, type WasmPippengerHandle } from './ import { loadSrsPoints, type SrsEvent } from './srs.js'; import { makeResultsClient } from './results_post.js'; +// BrowserStack's /5/worker launcher truncates the device URL at the first +// unescaped `&`, so only the first query param survives (every later param +// silently falls back to its page default). To pass a full param set to a +// BrowserStack run, base64-encode the intended query string and send it as a +// single `?cfg=` param (no `&`, nothing to truncate). We decode it +// here — before any URLSearchParams read — and rewrite location.search so the +// rest of the page sees the expanded params unchanged. +(() => { + try { + const cfg = new URLSearchParams(window.location.search).get('cfg'); + if (cfg) { + history.replaceState(null, '', window.location.pathname + '?' + atob(cfg)); + } + } catch { + // leave the URL as-is if cfg is malformed + } +})(); + type LogLevel = 'info' | 'ok' | 'err' | 'warn'; // Per-rep profiling capture consumed by the sweep aggregator. `runWebGpuOnce` @@ -64,7 +82,7 @@ const $results = document.getElementById('results') as HTMLDivElement; // The sweep spans 2^10..2^20 — small sizes show where the GPU pipeline // overtakes the WASM Pippenger; the v2 pipeline has no size floor. -const LOGN_MIN = 10; +const LOGN_MIN = 8; const LOGN_MAX = 20; const SRS_NUM_POINTS = 1 << LOGN_MAX; @@ -89,6 +107,8 @@ const gpuKnobs: MsmConfig = (() => { reduceWg: optInt('reducewg'), l0Log: optInt('l0log'), invVariant: q.get('inv') === 'loop' ? 'loop' : q.get('inv') === 'pk' ? 'pk' : undefined, + accum: q.get('accum') === 'coop' ? 'coop' : q.get('accum') === 'walker' ? 'walker' : undefined, + coopG: optInt('coopg'), profile: q.get('profile') === '1' || q.get('autorun') === 'msm-bench' || undefined, }; })(); @@ -1370,19 +1390,86 @@ function hideProgress(): void { // Sweep — ensureWasmBooted() takes care of the first-click boot. The SRS // fetch is just a download + JS-side decompression (no native workers), // so it's safe to run unconditionally at page load. +// For the msm-accum-ab BrowserStack sweep the SRS download runs in this boot +// block — before the autorun branch — and on mobile it dominates the wall. We +// emit progress from here (and reuse this client in the autorun branch) so the +// runner's first-progress/stall watchdog sees the SRS phase and both phases +// share ONE runId. A boot-start ping also distinguishes "device never reached +// the page / can't POST" from "device is busy loading SRS". +const accumAbActive = new URLSearchParams(window.location.search).get('autorun') === 'msm-accum-ab'; +const accumAbClient = accumAbActive ? makeResultsClient({ page: 'msm-accum-ab' }) : null; + (async () => { + // `?autorun=env-probe` — minimal device-capability probe. Posts UA + + // navigator.gpu presence + adapter info BEFORE any SRS download or WebGPU + // pipeline build, then exits. Isolates "can this BrowserStack browser load + // the page and POST back" from "does it expose WebGPU" from "can it pull the + // SRS over the tunnel" — the three confounded failure modes that made the + // earlier Android attempts ambiguous ("no heartbeat" told us nothing). + if (new URLSearchParams(window.location.search).get('autorun') === 'env-probe') { + const client = makeResultsClient({ page: 'env-probe' }); + const hasGpu = 'gpu' in navigator && !!(navigator as unknown as { gpu?: unknown }).gpu; + client.postProgress({ phase: 'boot-start', ua: navigator.userAgent, webgpu: hasGpu }); + const probe: Record = { ua: navigator.userAgent, webgpu: hasGpu }; + try { + if (hasGpu) { + const gpu = (navigator as unknown as { gpu: GPU }).gpu; + const adapter = await gpu.requestAdapter(); + probe.adapterRequested = !!adapter; + if (adapter) { + const info = (adapter as unknown as { info?: GPUAdapterInfo }).info ?? {}; + probe.adapterInfo = { + vendor: info.vendor, + architecture: info.architecture, + device: info.device, + description: info.description, + }; + const lim = adapter.limits as unknown as Record; + probe.limits = { + maxComputeWorkgroupStorageSize: lim?.maxComputeWorkgroupStorageSize, + maxStorageBuffersPerShaderStage: lim?.maxStorageBuffersPerShaderStage, + maxComputeInvocationsPerWorkgroup: lim?.maxComputeInvocationsPerWorkgroup, + maxComputeWorkgroupSizeX: lim?.maxComputeWorkgroupSizeX, + }; + } + } + } catch (e) { + probe.error = e instanceof Error ? `${e.message}` : String(e); + } + client.postProgress({ phase: 'probe-done', ...probe }); + await client.postResults({ state: 'done', params: { page: 'env-probe' }, results: probe }); + setBusy(false); + return; + } + setBusy(true, 'loading SRS…'); try { - srsBuf = await loadSrsPoints(SRS_NUM_POINTS, event => { + accumAbClient?.postProgress({ phase: 'boot-start', ua: navigator.userAgent, webgpu: 'gpu' in navigator }); + let lastSrsPost = 0; + // Optional `?srs_logn=N` caps the SRS download to 2^N points instead of + // the full 2^LOGN_MAX. A mobile BrowserStack run that only sweeps up to + // logn 16 has no reason to pull 16× the data (2^20) over the tunnel — the + // oversized download was overrunning the first-progress watchdog. + const srsLognCap = parseInt(new URLSearchParams(window.location.search).get('srs_logn') ?? '', 10); + const srsNumPoints = + Number.isFinite(srsLognCap) && srsLognCap >= LOGN_MIN && srsLognCap <= LOGN_MAX + ? 1 << srsLognCap + : SRS_NUM_POINTS; + srsBuf = await loadSrsPoints(srsNumPoints, event => { if (event.kind === 'info') { log('info', event.msg); } else if (event.kind === 'phase') { renderProgress(event); + // Heartbeat at most every 5 s so the watchdog sees the SRS download. + if (accumAbClient && performance.now() - lastSrsPost > 5000) { + lastSrsPost = performance.now(); + accumAbClient.postProgress({ phase: 'srs', srsPhase: event.phase, current: event.current, total: event.total }); + } } else if (event.kind === 'done') { hideProgress(); } }); - log('ok', `SRS loaded: ${SRS_NUM_POINTS.toLocaleString()} points available.`); + log('ok', `SRS loaded: ${srsNumPoints.toLocaleString()} points available.`); log( 'info', `WASM not booted yet (lazy). Click Run / Sweep — it'll spin up ` + @@ -1405,7 +1492,288 @@ function hideProgress(): void { // can pick them up from JSONL. const qp = new URLSearchParams(window.location.search); const autorun = qp.get('autorun'); - if (autorun === 'msm-bench') { + if (autorun === 'msm-noble') { + // Direct GPU-vs-noble cross-check at an arbitrary (small) logN. Unlike + // msm-cross-check (which compares GPU against the WASM Pippenger and only + // consults noble at logN=16), this generates the points/scalars with a + // noble mirror, runs the GPU MSM, and compares the affine result to + // noble's reference. Headless-SwiftShader friendly: no WASM boot. + const autorunLogN = parseInt(qp.get('logn') ?? '10', 10); + const client = makeResultsClient({ page: 'msm-noble' }); + log('info', `[noble-xcheck] logN=${autorunLogN}`); + try { + const inputs = await generateInputs(autorunLogN, /*mirrorForNoble=*/ true); + const { xy } = await runWebGpuOnce(inputs); + const ref = referenceMsm(inputs.points!, inputs.scalars!); + const agree = pointsEqual(xy, ref); + if (agree) { + log('ok', `[noble-xcheck] cross-check agree (gpu == noble) at logN=${autorunLogN}`); + } else { + log('err', `[noble-xcheck] cross-check disagreement: gpu.x=0x${xy.x.toString(16)} noble.x=0x${ref.x.toString(16)}`); + } + const state = agree ? 'done' : 'error'; + await client.postResults({ + state, + params: { logN: autorunLogN, page: 'msm-noble' }, + results: { cross_ok: agree, gpu_x: '0x' + xy.x.toString(16), noble_x: '0x' + ref.x.toString(16) }, + error: agree ? null : 'gpu != noble', + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log(state === 'done' ? 'ok' : 'err', `[autorun] state=${state}`); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `[noble-xcheck] FATAL: ${msg}`); + await client.postResults({ + state: 'error', + params: { logN: autorunLogN, page: 'msm-noble' }, + results: null, + error: msg, + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log('err', `[autorun] state=error`); + } + } else if (autorun === 'msm-accum-ab') { + // Same-device A/B benchmark of both bucket-accumulate kernels in ONE page + // load (identical thermal state, one BrowserStack worker for the pair). + // GPU-only, no WASM, no noble. Posts {walker, coop} min/median ms. + const logns = (qp.get('logns') ?? qp.get('logn') ?? '14') + .split(',').map(x => parseInt(x, 10)).filter(x => Number.isFinite(x)); + const reps = parseInt(qp.get('reps') ?? '12', 10); + const order = (qp.get('order') ?? 'walker,coop').split(',') as ('walker' | 'coop')[]; + const client = accumAbClient ?? makeResultsClient({ page: 'msm-accum-ab' }); + log('info', `[ab] logns=${logns.join(',')} reps=${reps} order=${order.join(',')}`); + try { + // Heartbeat into the JSONL while waiting for the SRS load (the only + // thing gating $run) so the BrowserStack first-progress/stall watchdog + // sees life during the (mobile, over-the-tunnel) download. + for (let i = 0; i < 1200; i++) { + if (!$run.disabled) break; + if (i % 20 === 0) client.postProgress({ phase: 'waiting-srs', i }); + await new Promise(r => setTimeout(r, 500)); + } + const device = await get_device(); + // Global warmup so the first real sweep entry doesn't eat one-time + // driver JIT / cold GPU-clock cost (otherwise the first logN reads + // anomalously high — e.g. seconds instead of ms). + { + const wi = await generateInputs(logns[0], false); + const wpool = await MsmV2Pool.create(device, wi.pointsBuf); + const wmsm = await MsmV2.create(device, wi.n, wpool, { ...gpuKnobs, accum: 'walker' }); + wmsm.prepare(wi.scalarsBuf); + await wmsm.run(); + await wmsm.run(); + wmsm.destroy(); + wpool.destroy(); + } + // Heartbeat so the BrowserStack watchdog sees first-progress before the + // (slow on mobile) per-(logN,accum) pipeline builds + reps run. The + // msm-accum-ab sweep otherwise only posts once at the very end. + client.postProgress({ phase: 'warmup-done', logns: logns.join(',') }); + const sweep: Record[] = []; + for (const logN of logns) { + const inputs = await generateInputs(logN, /*mirrorForNoble=*/ false); + const pool = await MsmV2Pool.create(device, inputs.pointsBuf); + const out: Record = {}; + for (const accum of order) { + client.postProgress({ phase: 'build', logN, accum }); + const msm = await MsmV2.create(device, inputs.n, pool, { ...gpuKnobs, accum }); + msm.prepare(inputs.scalarsBuf); + await msm.run(); // warmup (untimed) + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const t0 = performance.now(); + await msm.run(); + samples.push(performance.now() - t0); + client.postProgress({ phase: 'rep', logN, accum, rep: r, ms: samples[r] }); + } + msm.destroy(); + const sorted = [...samples].sort((a, b) => a - b); + const min = sorted[0]; + const median = sorted[Math.floor(sorted.length / 2)]; + const avg = samples.reduce((a, b) => a + b, 0) / samples.length; + out[accum] = { min, median, avg, samples }; + log('ok', `[ab] logN=${logN} accum=${accum}: min=${min.toFixed(1)} median=${median.toFixed(1)} ms`); + } + pool.destroy(); + const speedup = out.walker && out.coop ? out.walker.min / out.coop.min : null; + if (speedup !== null) { + log('ok', `[ab] logN=${logN} coop speedup vs walker (min): ${speedup.toFixed(3)}x`); + } + sweep.push({ logN, ...out, speedup_min: speedup }); + } + await client.postResults({ + state: 'done', + params: { logns: logns.join(','), reps, page: 'msm-accum-ab' }, + results: { reps, sweep }, + error: null, + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log('ok', `[autorun] state=done`); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `[ab] FATAL: ${msg}`); + await client.postResults({ + state: 'error', + params: { logns: logns.join(','), reps, page: 'msm-accum-ab' }, + results: null, + error: msg, + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log('err', `[autorun] state=error`); + } + } else if (autorun === 'msm-coop-gsweep') { + // Same-device granularity sweep of the coop-walker in ONE page load: a + // walker baseline plus coop at each inversion-granularity G (threads per + // shared batched inversion). One BrowserStack worker measures the whole + // curve under identical thermal state. GPU-only, no WASM, no noble. + const logns = (qp.get('logns') ?? qp.get('logn') ?? '14') + .split(',').map(x => parseInt(x, 10)).filter(x => Number.isFinite(x)); + const reps = parseInt(qp.get('reps') ?? '12', 10); + const gs = (qp.get('gsweep') ?? '1,8,16,32,64') + .split(',').map(x => parseInt(x, 10)).filter(x => Number.isFinite(x) && x > 0); + const client = makeResultsClient({ page: 'msm-coop-gsweep' }); + log('info', `[gsweep] logns=${logns.join(',')} reps=${reps} G=${gs.join(',')}`); + const benchOne = async ( + device: GPUDevice, pool: MsmV2Pool, inputs: Awaited>, + cfg: Parameters[3], + ) => { + const msm = await MsmV2.create(device, inputs.n, pool, cfg); + msm.prepare(inputs.scalarsBuf); + await msm.run(); // warmup (untimed) + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const t0 = performance.now(); + await msm.run(); + samples.push(performance.now() - t0); + } + msm.destroy(); + const sorted = [...samples].sort((a, b) => a - b); + return { min: sorted[0], median: sorted[Math.floor(sorted.length / 2)], + avg: samples.reduce((a, b) => a + b, 0) / samples.length, samples }; + }; + try { + for (let i = 0; i < 1200; i++) { + if (!$run.disabled) break; + await new Promise(r => setTimeout(r, 500)); + } + const device = await get_device(); + { + // Global warmup so the first timed entry doesn't eat cold-clock cost. + const wi = await generateInputs(logns[0], false); + const wpool = await MsmV2Pool.create(device, wi.pointsBuf); + const wmsm = await MsmV2.create(device, wi.n, wpool, { ...gpuKnobs, accum: 'walker' }); + wmsm.prepare(wi.scalarsBuf); + await wmsm.run(); + await wmsm.run(); + wmsm.destroy(); + wpool.destroy(); + } + const sweep: Record[] = []; + for (const logN of logns) { + const inputs = await generateInputs(logN, /*mirrorForNoble=*/ false); + const pool = await MsmV2Pool.create(device, inputs.pointsBuf); + const walker = await benchOne(device, pool, inputs, { ...gpuKnobs, accum: 'walker' }); + log('ok', `[gsweep] logN=${logN} walker: min=${walker.min.toFixed(1)} median=${walker.median.toFixed(1)} ms`); + client.postProgress({ logN, config: 'walker', min: walker.min, median: walker.median }); + const coop: Record = {}; + for (const g of gs) { + const r = await benchOne(device, pool, inputs, { ...gpuKnobs, accum: 'coop', coopG: g }); + coop[`g${g}`] = r; + log('ok', `[gsweep] logN=${logN} coop G=${g}: min=${r.min.toFixed(1)} median=${r.median.toFixed(1)} ms ` + + `(speedup vs walker ${(walker.min / r.min).toFixed(3)}x)`); + client.postProgress({ logN, config: `coop_g${g}`, min: r.min, median: r.median, speedup: walker.min / r.min }); + } + pool.destroy(); + const speedups: Record = {}; + for (const g of gs) speedups[`g${g}`] = walker.min / coop[`g${g}`].min; + sweep.push({ logN, walker, coop, speedup_min: speedups }); + } + await client.postResults({ + state: 'done', + params: { logns: logns.join(','), reps, gsweep: gs.join(','), page: 'msm-coop-gsweep' }, + results: { reps, gs, sweep }, + error: null, + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log('ok', `[autorun] state=done`); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `[gsweep] FATAL: ${msg}`); + await client.postResults({ + state: 'error', + params: { logns: logns.join(','), reps, gsweep: gs.join(','), page: 'msm-coop-gsweep' }, + results: null, + error: msg, + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log('err', `[autorun] state=error`); + } + } else if (autorun === 'msm-accum-bench') { + // GPU-only timed-reps benchmark for the bucket-accumulate kernel + // (selected via ?accum=walker|coop). Calls runWebGpuOnce directly — no + // WASM boot (so the 213-byte stub is irrelevant) and no noble — so it runs + // on real BrowserStack devices that lack the bb wasm. Posts per-rep ms + + // min/median for the run() window. + const autorunLogN = parseInt(qp.get('logn') ?? '16', 10); + const reps = parseInt(qp.get('reps') ?? '10', 10); + const accum = gpuKnobs.accum ?? 'walker'; + const client = makeResultsClient({ page: 'msm-accum-bench' }); + log('info', `[accum-bench] logN=${autorunLogN} reps=${reps} accum=${accum}`); + try { + // Wait for SRS (Run button leaves the perpetually-disabled state). + for (let i = 0; i < 1200; i++) { + if (!$run.disabled) break; + await new Promise(r => setTimeout(r, 500)); + } + const inputs = await generateInputs(autorunLogN, /*mirrorForNoble=*/ false); + const samples: number[] = []; + for (let r = 0; r < reps; r++) { + const { ms } = await runWebGpuOnce(inputs); + samples.push(ms); + log('info', `[accum-bench] rep ${r + 1}/${reps}: ${ms.toFixed(1)} ms`); + } + const sorted = [...samples].sort((a, b) => a - b); + const min = sorted[0]; + const median = sorted[Math.floor(sorted.length / 2)]; + const avg = samples.reduce((a, b) => a + b, 0) / samples.length; + log('ok', `[accum-bench] DONE accum=${accum} logN=${autorunLogN}: min=${min.toFixed(1)} median=${median.toFixed(1)} avg=${avg.toFixed(1)} ms`); + await client.postResults({ + state: 'done', + params: { logN: autorunLogN, reps, accum, page: 'msm-accum-bench' }, + results: { accum, logN: autorunLogN, samples, min, median, avg }, + error: null, + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log('ok', `[autorun] state=done`); + } catch (e) { + const msg = e instanceof Error ? `${e.message}\n${e.stack}` : String(e); + log('err', `[accum-bench] FATAL: ${msg}`); + await client.postResults({ + state: 'error', + params: { logN: autorunLogN, reps, accum, page: 'msm-accum-bench' }, + results: null, + error: msg, + log: [], + userAgent: navigator.userAgent, + hardwareConcurrency: navigator.hardwareConcurrency, + }); + log('err', `[autorun] state=error`); + } + } else if (autorun === 'msm-bench') { const autorunLogN = parseInt(qp.get('logn') ?? '17', 10); const reps = parseInt(qp.get('reps') ?? '5', 10); const client = makeResultsClient({ page: 'msm-bench' }); diff --git a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs index 3ae3843e2f55..fb9a8ebfceff 100644 --- a/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs +++ b/barretenberg/ts/dev/msm-webgpu/scripts/run-browserstack.mjs @@ -60,6 +60,8 @@ const { values: argv } = parseArgs({ "skip-tunnel": { type: "boolean", default: false }, "list-targets": { type: "boolean", default: false }, autorun: { type: "string", default: "msm-cross-check" }, + accum: { type: "string" }, + extra: { type: "string" }, "emit-body-only": { type: "boolean", default: false }, "external-worker-id-file": { type: "string" }, help: { type: "boolean", default: false }, @@ -417,7 +419,13 @@ async function main() { qp.set("autorun", argv.autorun); qp.set("logn", String(argv.n ?? "16")); if (argv.reps) qp.set("reps", String(argv.reps)); - const pageUrl = `${baseUrl}${pageMap[argv.page]}?${qp.toString()}`; + if (argv.accum) qp.set("accum", String(argv.accum)); + let extraQs = ""; + if (argv.extra) { + // Raw extra query params (e.g. "inv=loop&s=4"), appended verbatim. + extraQs = (argv.extra.startsWith("&") ? "" : "&") + String(argv.extra); + } + const pageUrl = `${baseUrl}${pageMap[argv.page]}?${qp.toString()}${extraQs}`; err(`page URL: ${pageUrl}`); // Generate a runId on the client side (the page makes its own random diff --git a/barretenberg/ts/src/msm_webgpu/COOP_WALKER_DESIGN.md b/barretenberg/ts/src/msm_webgpu/COOP_WALKER_DESIGN.md new file mode 100644 index 000000000000..c4ead168c244 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/COOP_WALKER_DESIGN.md @@ -0,0 +1,135 @@ +# Cooperative-inversion bucket accumulator ("coop-walker") + +Re-architecture of the MSM bucket-accumulate stage for laptop + mobile GPUs. +Grounded in the measured ground truth that the stream-walker accumulate kernel +is **memory-bound / occupancy-limited**, not inversion-bound. + +## Measured starting point (not re-derived) + +- The stream-walker accumulate kernel is extremely memory-bound on real + hardware. safegcd inversion *looks* like ~47% of the walker wall only + because memory stalls dilate it — the identical safegcd is <30% of MsmV2, + which is not memory-starved. Lever = **hide memory latency** + (occupancy / coalescing / fewer dependent gathers), not cheaper inversion. +- The stream-walker is per-thread bucket-monotonic. Each thread serially + walks a contiguous bucket range carrying **S** independent slot accumulators + in private registers, and stages forward-prefix products through a + `var pref_scratch` sized `TPB*S*2` vec4 = **16 KB at TPB=64,S=8**. + That long per-thread serial dependency chain + the 16 KB workgroup footprint + are why occupancy is low and memory latency is not hidden. +- Mobile reality: 16 KB workgroup memory (Mali) / 32 KB (Apple, Adreno); + only 10 storage buffers per stage on many mobile adapters; Android Chrome + has no timestamp-query (wall-time only). + +## Why the walker is occupancy-starved + +Two coupled costs both scale with **S** (slots per thread): + +1. **Register pressure.** Per slot the walker keeps `acc_x[8] + acc_y[8]` + (16 u32) plus 8 bookkeeping `array` (cursor, bucket_end, + task_end_sort, task_end_cur, cur_sorted, cur_bucket, is_first, slot_done, + split_start). At S=8 that is ~150+ live registers per invocation → few + resident invocations → memory latency is exposed. +2. **Workgroup memory.** `pref_scratch = TPB*S*2` vec4 = 16 KB at TPB=64,S=8. + On Mali (16 KB total shared) this caps the core to **one resident + workgroup**. No second workgroup means barriers and dependent gathers in + the resident workgroup stall the whole core. + +Shrinking S (the sibling "S-sweep") trades inversion amortization for +occupancy but leaves the *structure* — long per-thread serial chain, per-slot +carried state — intact. This design changes the structure instead. + +## The structural change: share one inversion across the workgroup + +Set **slots-per-thread = 1**. Each thread is a plain serial walker over one +contiguous slice of the sorted bucket stream (reusing the existing +`thread_cuts` partition unchanged). The batched-inversion that made affine +adds cheap is moved from *per-thread over S slots* to *per-workgroup over TPB +threads*: + +- Each round, every active thread produces exactly one `dx` for its pending + affine add (a retired thread contributes `dx = R`, Montgomery one, which is + inert). +- The workgroup computes the batch inverse of the TPB `dx` values + cooperatively: an exclusive **prefix-product scan** and an exclusive + **suffix-product scan** in workgroup memory, then a **single** safegcd + inversion of the workgroup-wide product (one thread), then + `inv_dx_t = inv_total * pre[t] * suf[t]`. +- Each thread applies its affine add with its `inv_dx_t` and advances. + +### What this buys, on every axis the ground truth cares about + +| Axis | stream-walker (TPB=64,S=8) | coop-walker (TPB=64,S=1) | +|---|---|---| +| Live registers / invocation | ~150+ (scales with S) | ~20 (one accumulator) | +| Workgroup memory | 16 KB (`TPB*S*2` vec4) | ~6 KB (dx + pre + suf, `3*TPB*2` vec4) | +| Independent adds in flight / round | S=8 per thread | TPB=64 per workgroup | +| safegcd inversions | ≈ total_adds / S | ≈ total_adds / TPB (**~8× fewer**) | +| Mali resident workgroups / core | 1 (16 KB cap) | ≥2 (6 KB) | + +Lower registers + lower workgroup memory → **higher occupancy** → more +resident workgroups to hide memory latency (the MsmV2 win) while still +**streaming** each point from global memory exactly once (the walker memory +footprint — no pair-tree materialization). The cooperative scan adds +`2*log2(TPB)` barriers per round, but with high occupancy those barriers are +hidden by sibling workgroups — exactly the latency-hiding regime MsmV2 proves +is reachable on this hardware. + +Fewer total inversions (~8×) is a bonus, not the point: the wall is memory, +and a shorter per-invocation serial chain with far more resident invocations +is what hides it. + +## I/O contract (drop-in for the existing pipeline) + +The coop kernel replaces only the `stream_walker` accumulate dispatch. It +reuses the entire surrounding pipeline (decompose → transpose → planner → +reduce → `walker_partials_index` → `walker_combine`) and keeps the exact same +output contract: + +- A bucket fully owned within one thread's range → full EC sum written to + `bucket_sums[bucket_id]`, no partial. +- A bucket split across a thread boundary → each thread writes its piece's + partial-sum to a unique slot (`2*t+0` split-start suffix, `2*t+1` task-end + prefix) with `partial_dest[slot] = bucket_id`; `walker_combine` sums them. +- Unused partial slots → `partial_dest = NO_BUCKET`. + +Because there is no S sub-split, the coop kernel emits **fewer** partials than +the walker (boundaries only at thread cuts, not task cuts), which also reduces +exposure to the known `walker_combine` `dx==0` incomplete-affine-add bug. + +## Status + +- [x] Headless-SwiftShader GPU-vs-noble cross-check harness + (`autorun=msm-noble`), GREEN at logn=8 and logn=10 for walker and coop. +- [x] coop-walker kernel + host wiring (selectable via `accum` knob, with the + inversion-granularity knob `G`). +- [x] cross-check coop at logn 8/10, multiple configs incl. `accum:'auto'`. +- [x] BrowserStack real-hardware A/B vs the stream-walker on real Apple M2, + Adreno (S25 Ultra), and Mali (Pixel 9 Pro XL). See PR #23739 for the + tables. + +### Measured outcome (the design's prediction was half-right) + +The occupancy thesis holds **on Adreno**: coop at **G=1** (each thread inverts +its own dx — no workgroup memory, no in-loop barriers, one accumulator/thread, +maximal occupancy) is **1.67–2.05× faster than the stream-walker** on a Galaxy +S25 Ultra across logN 12/14/16. The win comes purely from occupancy hiding +memory latency — and it does so *despite* G=1 doing ~S× MORE safegcd inversions +than the walker's S-wide batch (the design's "fewer inversions" via the +workgroup scan is irrelevant; the scan, G=TPB, is in fact the slowest coop mode +and triggers a device-lost on Adreno at logN≥14). + +It does **not** generalise: on Mali (Pixel 9 Pro XL) coop G=1 wins only at logN +12 and regresses at 14/16, and on Apple M2 it loses above logN 12 — neither +hides G=1's extra inversions the way Adreno does. So `accum:'auto'` selects +coop G=1 **only on Adreno/Qualcomm** and keeps the walker everywhere else. + +## Alternatives considered (documented, not pursued first) + +- **Stage points in bucket-sorted order** to remove the two-hop dependent + gather (`l0_index[cursor]` → `point_x[2*pt]`) from the hot loop and coalesce + reads. Rejected as the first move because the staging buffer (~n·64 B) adds + memory and a full extra streaming pass on an already memory-bound kernel; + worth revisiting as a workgroup-memory tile rather than a global buffer. +- **Drop Montgomery form** for the modest muls/element. Orthogonal to the + occupancy problem; not the lever. diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts b/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts index 39c9533d90cb..1a06949c1cda 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/gpu.ts @@ -48,9 +48,28 @@ export const get_device = async (): Promise => { const device = await adapter.requestDevice({ requiredFeatures, requiredLimits }); const grantedLimits = device.limits as unknown as Record; + // Stash the adapter's info for the per-device kernel selection in + // `resolveAccum` (the bucket-accumulate kernel choice keys off + // vendor/architecture). Newer Chrome exposes a read-only `device.adapterInfo` + // getter (which resolveAccum reads first); on engines without it we keep a + // copy under a private key. Never assign to `adapterInfo` — it is getter-only + // where present and assignment throws. + const adapterInfo = + (adapter as unknown as { info?: GPUAdapterInfo }).info ?? + (typeof (adapter as unknown as { requestAdapterInfo?: () => Promise }).requestAdapterInfo === 'function' + ? await (adapter as unknown as { requestAdapterInfo: () => Promise }).requestAdapterInfo() + : undefined); + if (adapterInfo) { + try { + (device as unknown as { __adapterInfo?: GPUAdapterInfo }).__adapterInfo = adapterInfo; + } catch { + // ignore — resolveAccum falls back to the native device.adapterInfo getter + } + } console.log( `[gpu] requested maxComputeWorkgroupStorageSize=${wgStorageMax}B,` + - ` granted=${grantedLimits['maxComputeWorkgroupStorageSize']}B`, + ` granted=${grantedLimits['maxComputeWorkgroupStorageSize']}B` + + ` adapter="${adapterInfo?.vendor ?? ''}/${adapterInfo?.architecture ?? ''}"`, ); return device; }; diff --git a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts index ac4992d67ded..07bc30a3b40f 100644 --- a/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts +++ b/barretenberg/ts/src/msm_webgpu/cuzk/shader_manager.ts @@ -20,6 +20,7 @@ import { ba_partial_sum as ba_partial_sum_shader, // Stream-walker (STREAM_WALKER_PLAN.md §6, plus C's KNOB 2 variant). ba_planner_partition_task as ba_planner_partition_task_shader, + ba_coop_walker as ba_coop_walker_shader, ba_stream_walker as ba_stream_walker_shader, ba_walker_combine as ba_walker_combine_shader, ba_walker_partials_index as ba_walker_partials_index_shader, @@ -929,6 +930,48 @@ ${packLines.join('\n')} ); } + // Cooperative-inversion bucket accumulator (coop-walker). Drop-in for the + // stream_walker dispatch (same bind group + indirect args): one task per + // thread, with the batched inversion shared across the workgroup via a + // prefix/suffix product scan + a single safegcd inversion per round. + public gen_ba_coop_walker_shader( + workgroup_size: number, + s: number, + variant: 'loop' | 'pk' = 'pk', + g: number = workgroup_size, + ): string { + const dec = this.decoupledPackUnpackWgsl(); + const inverse_funcs = by_inverse_loop_funcs; + const inv_fn = variant === 'pk' ? 'fr_inv_by_loop_pk' : 'fr_inv_by_loop'; + const { p8_consts, r8_csv, f8_words } = this.f8Context(); + // Inversion granularity: G==TPB -> cooperative prefix/suffix scan; + // 1 per-group serial Montgomery batch inversion; G==1 -> each + // thread inverts its own dx (no workgroup memory, no barriers). + const gClamped = Math.max(1, Math.min(g, workgroup_size)); + const coop_local = gClamped === 1; + const coop_scan = gClamped >= workgroup_size; + const coop_group = !coop_local && !coop_scan; + return mustache.render( + ba_coop_walker_shader, + { + workgroup_size, s, inv_fn, g: gClamped, + coop_local, coop_scan, coop_group, + p8_consts, r8_csv, f8_words, + word_size: this.word_size, num_words: this.num_words, n0: this.n0, + p_limbs: this.p_limbs, r_limbs: this.r_limbs, r_cubed_limbs: this.r_cubed_limbs, + p_minus_2_limbs: this.p_minus_2_limbs, mask: this.mask, + two_pow_word_size: this.two_pow_word_size, p_inv_mod_2w: this.p_inv_mod_2w, + p_inv_by_a_lo: this.p_inv_by_a_lo, + dec_unpack: dec.unpack, dec_pack: dec.pack, recompile: this.recompile, + }, + { + structs, bigint_funcs, + montgomery_product_funcs: this.mont_product_src, + field_funcs, field8_funcs, fr_pow_funcs, bigint_by_funcs, inverse_funcs, + }, + ); + } + // Stream-walker partials indexer (task #19): one thread per partial slot. // Builds a per-bucket linked list in (bucket_head, nodes_slot, nodes_next) // using atomicCompareExchange — replaces walker_combine's O(num_dense × diff --git a/barretenberg/ts/src/msm_webgpu/msm_v2.ts b/barretenberg/ts/src/msm_webgpu/msm_v2.ts index e623b6f881de..bc39dac36b83 100644 --- a/barretenberg/ts/src/msm_webgpu/msm_v2.ts +++ b/barretenberg/ts/src/msm_webgpu/msm_v2.ts @@ -64,6 +64,26 @@ export interface MsmConfig { profile?: boolean; /** Phase-2 hook — Jacobian-crossover threshold. Accepted but inert in Phase 1. */ jacobianCrossover?: number; + /** + * Bucket-accumulate kernel. + * - 'walker' = per-thread S-slot stream-walker (each thread inverts its own + * S-wide batch in parallel with every other thread). + * - 'coop' = cooperative-inversion accumulator (one task per thread, a single + * safegcd inversion shared across the workgroup via a prefix/suffix scan). + * - 'auto' (default) = pick per device: 'coop' on memory-/register-starved + * mobile GPUs (Adreno/Mali), 'walker' on cache-rich desktop GPUs. See + * {@link resolveAccum}. + * Drop-in: all reuse the same bind group, indirect dispatch, and combine path. + */ + accum?: 'walker' | 'coop' | 'auto'; + /** + * Cooperative-inversion granularity for `accum: 'coop'`: number of threads + * that share ONE batched inversion. `WALKER_TPB` (default) = workgroup-wide + * prefix/suffix scan; `1 < G < TPB` = per-group serial Montgomery batch + * inversion (TPB/G concurrent inversions); `1` = each thread inverts its own + * dx (no workgroup memory, no barriers). Must divide the workgroup size. + */ + coopG?: number; /** * Discarded warm-up `run()`s in `create()` — they ramp the GPU clock and pay * the shader-JIT / command-buffer cold start before the first timed run. @@ -406,6 +426,55 @@ function pickReduceWg(c: number): number { return c <= 9 ? 32 : c <= 12 ? 64 : 128; } +// `auto` selects the cooperative-inversion accumulator with G=1 only on Adreno +// (Qualcomm) GPUs, where it is measured fastest; everything else uses the +// walker. The choice is data-driven, not a blanket "mobile → coop": +// - Adreno (Galaxy S25 Ultra, Chrome 145): coop G=1 runs 1.67–2.05× faster +// than the stream-walker across logN 12/14/16, speedup decaying monotonically +// with G (g8≈1.5×, g16≈1.2–1.3×, g32≈parity). G=1 = each thread inverts its +// own dx: no workgroup memory, no in-loop barriers, one accumulator/thread, +// so occupancy is maximal — which on Adreno hides the cost of doing ~S× more +// safegcd inversions than the walker's S-wide batch. The workgroup-scan +// default (G=TPB) is the WORST coop mode there and triggered a device-lost +// at logN≥14. +// - Mali (Pixel 9 Pro XL, Chrome 145): coop G=1 wins only at logN 12 (1.31×) +// and REGRESSES at logN 14/16 (0.87×, 0.79×) — Mali does not hide G=1's ~S× +// extra inversions. So Mali stays on the walker. +// - Cache-rich desktop GPUs (Apple M2): the walker wins outright. Walker is +// never selected against here. +// G=1 is the only coop granularity `auto` ever ships; the workgroup-scan default +// (G=TPB=64) is kept only for explicit `accum:'coop'` benchmarking. +const COOP_G_AUTO = 1; +const COOP_G_DEFAULT = 64; + +// True for the GPU family where coop G=1 is measured fastest (Adreno/Qualcomm). +// Newer Chrome exposes a read-only `device.adapterInfo` getter; engines without +// it get the copy stashed by `get_device` under `__adapterInfo`. +function coopWinsHere(device: GPUDevice): { coopWins: boolean; hay: string } { + const info = ((device as unknown as { adapterInfo?: GPUAdapterInfo }).adapterInfo ?? + (device as unknown as { __adapterInfo?: GPUAdapterInfo }).__adapterInfo ?? + {}) as Partial; + const hay = `${info.vendor ?? ''} ${info.architecture ?? ''} ${info.device ?? ''} ${info.description ?? ''}`.toLowerCase(); + return { coopWins: /adreno|qualcomm|snapdragon/.test(hay), hay }; +} + +// Resolve the bucket-accumulate kernel + inversion granularity for this device. +// Explicit `accum` is honoured (G defaults to the scan unless overridden); `auto` +// (the default) picks coop G=1 on Adreno (measured 1.67–2.05× over the walker) +// and the walker everywhere else. An explicit `coopG` always wins. +function resolveAccum( + requestedAccum: 'walker' | 'coop' | 'auto' | undefined, + requestedG: number | undefined, + device: GPUDevice, +): { accum: 'walker' | 'coop'; coopG: number } { + if (requestedAccum === 'walker') return { accum: 'walker', coopG: requestedG ?? COOP_G_DEFAULT }; + if (requestedAccum === 'coop') return { accum: 'coop', coopG: requestedG ?? COOP_G_DEFAULT }; + const { coopWins, hay } = coopWinsHere(device); + const accum: 'walker' | 'coop' = coopWins ? 'coop' : 'walker'; + console.log(`[MsmV2] accum=auto -> '${accum}'${accum === 'coop' ? ` G=${requestedG ?? COOP_G_AUTO}` : ''} (adapter: "${hay.trim()}")`); + return { accum, coopG: requestedG ?? COOP_G_AUTO }; +} + // Per-level GPU dispatch wiring for one prepared scalar set. interface LevelBind { plannerABind: GPUBindGroup; @@ -1265,6 +1334,10 @@ export class MsmV2 { private streamWalkerPipe!: GPUComputePipeline; private streamWalkerLayout!: GPUBindGroupLayout; private streamWalkerBind!: GPUBindGroup; + // Cooperative-inversion accumulator (reuses streamWalkerLayout + bind). + private coopWalkerPipe!: GPUComputePipeline; + private accum: 'walker' | 'coop' = 'walker'; + private coopG = 64; private walkerCombinePipe!: GPUComputePipeline; private walkerCombineLayout!: GPUBindGroupLayout; private walkerCombineBind!: GPUBindGroup; @@ -1384,6 +1457,9 @@ export class MsmV2 { m.invVariant = config?.invVariant ?? DEFAULT_INV_VARIANT; m.addsub = config?.addsub ?? 'native'; m.jacobianCrossover = config?.jacobianCrossover ?? 0; + const resolved = resolveAccum(config?.accum, config?.coopG, device); + m.accum = resolved.accum; + m.coopG = resolved.coopG; m.combineOnHost = config?.combineOnHost ?? true; const wantProfile = config?.profile ?? false; m.profile = wantProfile && device.features.has('timestamp-query'); @@ -1626,6 +1702,13 @@ export class MsmV2 { m.streamWalkerPipe = await compile( sm.gen_ba_stream_walker_shader(WALKER_TPB, STREAM_S, INV_VARIANT), `stream-walker`, m.streamWalkerLayout); + // coop-walker shares the indirect-dispatch grain (ceil(num_active/TPB)) + // and the stream-walker bind group; only compiled when selected. + if (m.accum === 'coop') { + m.coopWalkerPipe = await compile( + sm.gen_ba_coop_walker_shader(WALKER_TPB, STREAM_S, INV_VARIANT, m.coopG), + `coop-walker`, m.streamWalkerLayout); + } m.walkerCombinePipe = await compile( sm.gen_ba_walker_combine_shader(STREAM_S, INV_VARIANT), `walker-combine`, m.walkerCombineLayout); @@ -2369,7 +2452,8 @@ export class MsmV2 { // partition_task wrote the walker's indirect args to planner_meta[15..17] // (= byte offset 60 = 15 * 4). setPhase('stream_walker'); - indirectDispatch(this.streamWalkerPipe, this.streamWalkerBind, spMeta, 15 * 4); + const accumPipe = this.accum === 'coop' ? this.coopWalkerPipe : this.streamWalkerPipe; + indirectDispatch(accumPipe, this.streamWalkerBind, spMeta, 15 * 4); // Task #19: per-bucket linked-list index (atomic CAS pass over // partial_dest) replaces walker_combine's O(num_dense × M_partials) // scan with O(M_partials) indexing + O(num_partials_per_bucket) walks. diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts index e5321b600205..9b1f8bebf9be 100644 --- a/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts +++ b/barretenberg/ts/src/msm_webgpu/wgsl/_generated/shaders.ts @@ -1,6 +1,6 @@ // AUTO-GENERATED by scripts/inline-wgsl.mjs. DO NOT EDIT. // Run `yarn generate:wgsl` (or `node scripts/inline-wgsl.mjs`) to regenerate. -// 43 shader sources inlined. +// 44 shader sources inlined. /* eslint-disable */ @@ -915,6 +915,501 @@ fn bigint_f32_is_zero(x: ptr) -> bool { } `; +export const ba_coop_walker = `{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> inverse_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> field8_funcs }} + +// Cooperative-inversion bucket accumulator ("coop-walker"). +// +// Structural re-architecture of ba_stream_walker. Each thread owns ONE +// contiguous task (its whole [thread_cut, next_thread_cut) range — read as +// cut 0 .. cut S of the per-thread task_cuts block, so this is a drop-in for +// the stream_walker bind group and indirect dispatch). Instead of each thread +// carrying S private slot accumulators and running its own S-wide batched +// inversion through a 16 KB var pref_scratch, the batched inversion +// is shared across the whole workgroup: every active thread contributes one +// dx per round, the workgroup computes all TPB inverses with a cooperative +// prefix/suffix product scan plus a SINGLE safegcd inversion, and each thread +// applies its affine add. +// +// Why: the walker is memory-bound / occupancy-limited; its occupancy is capped +// by per-thread register pressure (~150+ regs at S=8) and the 16 KB workgroup +// footprint (one resident workgroup on Mali). coop-walker drops per-thread +// state to a single accumulator (~20 regs) and workgroup memory to ~4 KB +// (two TPB-wide 256-bit scan arrays), so many more workgroups stay resident to +// hide memory latency — MsmV2-like occupancy at stream-walker memory. +// +// Output contract is identical to ba_stream_walker (so walker_partials_index + +// walker_combine + reduce are reused unchanged): +// - a bucket fully inside one thread's range -> bucket_sums[bucket_id] +// - a bucket split across a thread boundary -> partials at slot +// 2*(t*S+0)+{0,1} (split-start suffix / task-end prefix), summed by +// walker_combine. +// S is retained only for partial-slot layout compatibility with the shared +// partials buffer; coop-walker runs exactly ONE task per thread. +// +// params.x = NUM_THREADS, params.y = IDLE_ANCHOR, +// params.z = M_buckets, params.w = M_partials. + +const S: u32 = {{ s }}u; +const CUTS: u32 = S + 1u; +const TPB: u32 = {{ workgroup_size }}u; +// Inversion granularity: number of threads that share ONE batched inversion. +// G==TPB -> cooperative prefix/suffix scan (one inversion per workgroup). +// 1 per-group serial Montgomery batch inversion (TPB/G inversions, +// one per group leader, run concurrently across leaders). +// G==1 -> each thread inverts its own dx (no workgroup memory, no barriers). +const G: u32 = {{ g }}u; +const PG: u32 = 2u; +const L0_SIGN_BIT: u32 = 0x80000000u; +const L0_IDX_MASK: u32 = 0x7fffffffu; +const NO_BUCKET: u32 = 0xffffffffu; + +@group(0) @binding(0) var sorted_bucket_list: array; +@group(0) @binding(1) var sorted_count_list: array; +@group(0) @binding(2) var offsets: array; +@group(0) @binding(3) var task_cuts: array; +@group(0) @binding(4) var l0_index: array; +@group(0) @binding(5) var point_x: array>; +@group(0) @binding(6) var point_y: array>; +@group(0) @binding(7) var bucket_sums: array>; +@group(0) @binding(8) var partials_buf: array>; +@group(0) @binding(9) var partial_dest: array; +@group(0) @binding(10) var params: vec4; + +{{#coop_scan}} +// Two TPB-wide 256-bit scratch planes for the cooperative batch inversion: +// wpre becomes the inclusive prefix products, wsuf the inclusive suffix +// products. 2 vec4 per slot. ~4 KB total at TPB=64 (vs the walker's 16 KB). +var wpre: array, TPB * 2u>; +var wsuf: array, TPB * 2u>; +var w_inv_total: array, 2u>; +{{/coop_scan}} +{{#coop_group}} +// Per-group serial batch inversion: wdx holds each thread's dx (then is +// overwritten with its inv_dx); wpx holds the running prefix products the +// group leader needs for the backward pass. 2 vec4 per slot, ~4 KB total. +var wdx: array, TPB * 2u>; +var wpx: array, TPB * 2u>; +{{/coop_group}} +{{^coop_local}} +var w_any_active: atomic; +// Mirror of the activity flag read through workgroupUniformLoad so the loop +// break is a provably-uniform value (atomic loads are not, which would make +// the in-loop barriers fail Tint's uniformity analysis). +var w_active_flag: u32; +{{/coop_local}} + +fn load_pt_x(cursor: u32) -> array { + let packed = l0_index[cursor]; + let pt = packed & L0_IDX_MASK; + let q0 = point_x[2u * pt]; + let q1 = point_x[2u * pt + 1u]; + return array(q0.x, q0.y, q0.z, q0.w, q1.x, q1.y, q1.z, q1.w); +} + +fn load_pt_y(cursor: u32) -> array { + let packed = l0_index[cursor]; + let pt = packed & L0_IDX_MASK; + let q0 = point_y[2u * pt]; + let q1 = point_y[2u * pt + 1u]; + let y = array(q0.x, q0.y, q0.z, q0.w, q1.x, q1.y, q1.z, q1.w); + if ((packed & L0_SIGN_BIT) == 0u) { return y; } + let zero = array(0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); + return fr_sub_f8(zero, y); +} + +fn store_bucket_sum(bucket_id: u32, M: u32, x_val: array, y_val: array) { + let bx = PG * bucket_id; + bucket_sums[bx + 0u] = vec4(x_val[0], x_val[1], x_val[2], x_val[3]); + bucket_sums[bx + 1u] = vec4(x_val[4], x_val[5], x_val[6], x_val[7]); + let by = PG * M + PG * bucket_id; + bucket_sums[by + 0u] = vec4(y_val[0], y_val[1], y_val[2], y_val[3]); + bucket_sums[by + 1u] = vec4(y_val[4], y_val[5], y_val[6], y_val[7]); +} + +fn store_partial(pslot: u32, bucket_id: u32, M: u32, x_val: array, y_val: array) { + let bx = PG * pslot; + partials_buf[bx + 0u] = vec4(x_val[0], x_val[1], x_val[2], x_val[3]); + partials_buf[bx + 1u] = vec4(x_val[4], x_val[5], x_val[6], x_val[7]); + let by = PG * M + PG * pslot; + partials_buf[by + 0u] = vec4(y_val[0], y_val[1], y_val[2], y_val[3]); + partials_buf[by + 1u] = vec4(y_val[4], y_val[5], y_val[6], y_val[7]); + partial_dest[pslot] = bucket_id; +} + +{{#coop_scan}} +fn wstore(arr_pre: bool, l: u32, v: array) { + let a = vec4(v[0], v[1], v[2], v[3]); + let b = vec4(v[4], v[5], v[6], v[7]); + if (arr_pre) { + wpre[2u * l + 0u] = a; + wpre[2u * l + 1u] = b; + } else { + wsuf[2u * l + 0u] = a; + wsuf[2u * l + 1u] = b; + } +} + +fn wload_pre(l: u32) -> array { + let a = wpre[2u * l + 0u]; + let b = wpre[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} + +fn wload_suf(l: u32) -> array { + let a = wsuf[2u * l + 0u]; + let b = wsuf[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} +{{/coop_scan}} +{{#coop_group}} +fn wdx_store(l: u32, v: array) { + wdx[2u * l + 0u] = vec4(v[0], v[1], v[2], v[3]); + wdx[2u * l + 1u] = vec4(v[4], v[5], v[6], v[7]); +} +fn wdx_load(l: u32) -> array { + let a = wdx[2u * l + 0u]; + let b = wdx[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} +fn wpx_store(l: u32, v: array) { + wpx[2u * l + 0u] = vec4(v[0], v[1], v[2], v[3]); + wpx[2u * l + 1u] = vec4(v[4], v[5], v[6], v[7]); +} +fn wpx_load(l: u32) -> array { + let a = wpx[2u * l + 0u]; + let b = wpx[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} +{{/coop_group}} + +// Single field inversion in Montgomery form (unpack -> safegcd -> repack). +fn finv8(v: array) -> array { + var lin = unpack256_to_limbs(v); + var lout = {{ inv_fn }}(lin); + let p = pack_limbs_to_256(&lout); + return array(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7]); +} + +fn coop_is_zero_f8(v: array) -> bool { + return (v[0] | v[1] | v[2] | v[3] | v[4] | v[5] | v[6] | v[7]) == 0u; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let t = gid.x; + let l = lid.x; + let NUM_THREADS = params.x; + let IDLE_ANCHOR = params.y; + let M_buckets = params.z; + let M_partials = params.w; + + // Per-thread scalar state (one task per thread; acc lives in registers). + var cursor: u32 = 0u; // l0_index point position + var bucket_end: u32 = 0u; // l0 position past current bucket + var task_end_sort: u32 = 0u; // sorted index of the task's last bucket + var task_end_cur: u32 = 0u; // l0 position past the task within that bucket + var cur_sorted: u32 = 0u; // index into sorted_bucket_list + var cur_bucket: u32 = 0u; // bucket id (for bucket_sums) + var is_first: u32 = 1u; + var slot_done: u32 = 1u; // default idle (covers t >= active range) + var split_start: u32 = 0u; // current bucket shared with a prior task + var acc_x: array; + var acc_y: array; + + // Slot-layout-compatible NO_BUCKET init: clear all S partial-slot pairs so + // the shared partials buffer is well-defined for walker_partials_index + // (matches ba_stream_walker's coverage of slots 2*(t*S+k)+{0,1}). + if (t < NUM_THREADS) { + for (var k: u32 = 0u; k < S; k = k + 1u) { + partial_dest[2u * (t * S + k) + 0u] = NO_BUCKET; + partial_dest[2u * (t * S + k) + 1u] = NO_BUCKET; + } + } + + // Initialise the single task from cut 0 (start) .. cut S (end). Mirrors + // ba_stream_walker's per-slot init for the whole thread range. + if (t < NUM_THREADS) { + let cut_base = t * CUTS * 2u; + let sb = task_cuts[cut_base + 0u]; + let so = task_cuts[cut_base + 1u]; + let eb = task_cuts[cut_base + S * 2u + 0u]; + let eo = task_cuts[cut_base + S * 2u + 1u]; + + let sb_id = sorted_bucket_list[sb]; + let sb_base = offsets[sb_id]; + let sb_count = sorted_count_list[sb]; + + var eff_sorted = sb; + var eff_id = sb_id; + var eff_base = sb_base; + var eff_count = sb_count; + var start_cursor: u32; + if (so == 0u) { + start_cursor = sb_base; + split_start = 0u; + } else if (so + 1u < sb_count) { + start_cursor = sb_base + so + 1u; + split_start = 1u; + } else { + eff_sorted = sb + 1u; + eff_id = sorted_bucket_list[eff_sorted]; + eff_base = offsets[eff_id]; + eff_count = sorted_count_list[eff_sorted]; + start_cursor = eff_base; + split_start = 0u; + } + + var te_sort: u32; + var te_cur: u32; + if (eo > 0u) { + te_sort = eb; + te_cur = offsets[sorted_bucket_list[eb]] + eo + 1u; + } else if (eb > 0u) { + te_sort = eb - 1u; + let pid = sorted_bucket_list[te_sort]; + te_cur = offsets[pid] + sorted_count_list[te_sort]; + } else { + te_sort = 0u; + te_cur = 0u; + } + + cursor = start_cursor; + bucket_end = eff_base + eff_count; + task_end_sort = te_sort; + task_end_cur = te_cur; + cur_sorted = eff_sorted; + cur_bucket = eff_id; + is_first = 1u; + slot_done = 0u; + + // Empty task (region-aware): start at or past the task end. + if (eff_sorted > te_sort || (eff_sorted == te_sort && start_cursor >= te_cur)) { + slot_done = 1u; + } else { + // Single-point leading segment: a split continuation landing on a + // bucket's last point can't go through the 2-point is_first step. + var seg_end = bucket_end; + if (eff_sorted == te_sort) { seg_end = te_cur; } + if (split_start == 1u && seg_end - start_cursor == 1u) { + let px = load_pt_x(start_cursor); + let py = load_pt_y(start_cursor); + if (eff_sorted == te_sort) { + store_partial(2u * (t * S + 0u) + 1u, eff_id, M_partials, px, py); + slot_done = 1u; + } else { + store_partial(2u * (t * S + 0u) + 0u, eff_id, M_partials, px, py); + let nxt = eff_sorted + 1u; + let nxt_id = sorted_bucket_list[nxt]; + let nxt_base = offsets[nxt_id]; + cur_sorted = nxt; + cur_bucket = nxt_id; + bucket_end = nxt_base + sorted_count_list[nxt]; + cursor = nxt_base; + split_start = 0u; + if (nxt > te_sort) { slot_done = 1u; } + } + } + } + } + + // Main loop. For the cooperative modes (scan / group) every iteration is + // uniform across the workgroup: the break is decided by a workgroup-shared + // activity flag, so all threads execute the same barriers regardless of + // when their own task ends. The local mode (G==1) has no in-loop barriers, + // so each thread simply runs until its own task is done. + loop { +{{^coop_local}} + workgroupBarrier(); + if (l == 0u) { atomicStore(&w_any_active, 0u); } + workgroupBarrier(); + if (slot_done == 0u) { atomicStore(&w_any_active, 1u); } + workgroupBarrier(); + if (l == 0u) { w_active_flag = atomicLoad(&w_any_active); } + // Uniform read (implicit barrier) so the break below is uniform. + let any_active = workgroupUniformLoad(&w_active_flag); + if (any_active == 0u) { break; } +{{/coop_local}} +{{#coop_local}} + if (slot_done == 1u) { break; } +{{/coop_local}} + + // Each active thread computes its pending dx; idle threads contribute + // Montgomery one (inert in the batch product). + var dx: array; + if (slot_done == 1u) { + dx = get_r_f8(); + } else { + var p_lx: array; + var p_rx: array; + if (is_first == 1u) { + p_lx = load_pt_x(cursor); + p_rx = load_pt_x(cursor + 1u); + } else { + p_lx = acc_x; + p_rx = load_pt_x(cursor); + } + dx = fr_sub_f8(p_rx, p_lx); + // Guard the batch product: a zero dx (equal x-coords, a measure-zero + // doubling case for distinct SRS points) would zero the whole group + // product. Substitute one so the failure stays isolated. + if (coop_is_zero_f8(dx)) { dx = get_r_f8(); } + } + +{{#coop_scan}} + wstore(true, l, dx); + wstore(false, l, dx); + + // Interleaved inclusive prefix- (wpre) and suffix- (wsuf) product + // scans (Hillis-Steele). Both share the same step schedule, so fusing + // them halves the workgroup-barrier count versus two sequential scans. + for (var off: u32 = 1u; off < TPB; off = off << 1u) { + workgroupBarrier(); + var ptmp: array; + var stmp: array; + let pact = l >= off; + let sact = l + off < TPB; + if (pact) { ptmp = montgomery_product_f8(wload_pre(l - off), wload_pre(l)); } + if (sact) { stmp = montgomery_product_f8(wload_suf(l), wload_suf(l + off)); } + workgroupBarrier(); + if (pact) { wstore(true, l, ptmp); } + if (sact) { wstore(false, l, stmp); } + } + workgroupBarrier(); + + // One inversion for the whole workgroup: invert the total product. + if (l == 0u) { + let total = wload_pre(TPB - 1u); + var total20 = unpack256_to_limbs(total); + var invtot20 = {{ inv_fn }}(total20); + let invtot = pack_limbs_to_256(&invtot20); + w_inv_total[0u] = vec4(invtot[0], invtot[1], invtot[2], invtot[3]); + w_inv_total[1u] = vec4(invtot[4], invtot[5], invtot[6], invtot[7]); + } + workgroupBarrier(); + let it0 = w_inv_total[0u]; + let it1 = w_inv_total[1u]; + let inv_total = array(it0.x, it0.y, it0.z, it0.w, it1.x, it1.y, it1.z, it1.w); + + // inv_dx_l = inv_total * (prod_{jl} dx_j). + var pre_excl: array; + if (l == 0u) { pre_excl = get_r_f8(); } else { pre_excl = wload_pre(l - 1u); } + var suf_excl: array; + if (l + 1u >= TPB) { suf_excl = get_r_f8(); } else { suf_excl = wload_suf(l + 1u); } + var inv_dx = montgomery_product_f8(inv_total, pre_excl); + inv_dx = montgomery_product_f8(inv_dx, suf_excl); +{{/coop_scan}} +{{#coop_group}} + // Per-group serial Montgomery batch inversion. Each group of G threads + // shares ONE safegcd inversion; the TPB/G group leaders run their + // inversions concurrently (one per leader). Only 2 barriers per round + // regardless of G, versus the scan's 2*log2(TPB). + wdx_store(l, dx); + workgroupBarrier(); + if ((l % G) == 0u) { + // Forward prefix products over the group's G dx values. + var run = wdx_load(l); + wpx_store(l, run); + for (var i: u32 = 1u; i < G; i = i + 1u) { + run = montgomery_product_f8(run, wdx_load(l + i)); + wpx_store(l + i, run); + } + // One inversion of the group product, then the backward pass: each + // inv_dx overwrites its dx slot in wdx. + var inv = finv8(run); + for (var i: u32 = G - 1u; i >= 1u; i = i - 1u) { + let invi = montgomery_product_f8(inv, wpx_load(l + i - 1u)); + inv = montgomery_product_f8(inv, wdx_load(l + i)); + wdx_store(l + i, invi); + } + wdx_store(l, inv); + } + workgroupBarrier(); + let inv_dx = wdx_load(l); +{{/coop_group}} +{{#coop_local}} + // Each thread inverts its own dx — no workgroup memory, no barriers. + let inv_dx = finv8(dx); +{{/coop_local}} + + if (slot_done == 0u) { + var p_lx: array; + var p_ly: array; + var p_rx: array; + var p_ry: array; + if (is_first == 1u) { + p_lx = load_pt_x(cursor); + p_ly = load_pt_y(cursor); + p_rx = load_pt_x(cursor + 1u); + p_ry = load_pt_y(cursor + 1u); + cursor = cursor + 2u; + } else { + p_lx = acc_x; + p_ly = acc_y; + p_rx = load_pt_x(cursor); + p_ry = load_pt_y(cursor); + cursor = cursor + 1u; + } + + var lambda = fr_sub_f8(p_ry, p_ly); + lambda = montgomery_product_f8(lambda, inv_dx); + var r_x = montgomery_product_f8(lambda, lambda); + let x_sum = fr_add_f8(p_lx, p_rx); + r_x = fr_sub_f8(r_x, x_sum); + var r_y = fr_sub_f8(p_lx, r_x); + r_y = montgomery_product_f8(lambda, r_y); + r_y = fr_sub_f8(r_y, p_ly); + + let task_done = (cur_sorted == task_end_sort) && (cursor >= task_end_cur); + let bucket_done = cursor >= bucket_end; + + if (task_done) { + let is_partial = (split_start == 1u) || (cursor < bucket_end); + if (is_partial) { + store_partial(2u * (t * S + 0u) + 1u, cur_bucket, M_partials, r_x, r_y); + } else { + store_bucket_sum(cur_bucket, M_buckets, r_x, r_y); + } + slot_done = 1u; + } else if (bucket_done) { + if (split_start == 1u) { + store_partial(2u * (t * S + 0u) + 0u, cur_bucket, M_partials, r_x, r_y); + } else { + store_bucket_sum(cur_bucket, M_buckets, r_x, r_y); + } + let nxt = cur_sorted + 1u; + let nxt_id = sorted_bucket_list[nxt]; + let nxt_base = offsets[nxt_id]; + cur_sorted = nxt; + cur_bucket = nxt_id; + bucket_end = nxt_base + sorted_count_list[nxt]; + cursor = nxt_base; + is_first = 1u; + split_start = 0u; + } else { + acc_x = r_x; + acc_y = r_y; + is_first = 0u; + } + } + } + + {{{ recompile }}} +} +`; + export const ba_partial_sum = `{{> structs }} {{> bigint_funcs }} {{> montgomery_product_funcs }} diff --git a/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_coop_walker.template.wgsl b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_coop_walker.template.wgsl new file mode 100644 index 000000000000..3f41877173f7 --- /dev/null +++ b/barretenberg/ts/src/msm_webgpu/wgsl/cuzk/ba_coop_walker.template.wgsl @@ -0,0 +1,493 @@ +{{> structs }} +{{> bigint_funcs }} +{{> montgomery_product_funcs }} +{{> field_funcs }} +{{> fr_pow_funcs }} +{{> bigint_by_funcs }} +{{> inverse_funcs }} + +{{{ dec_unpack }}} + +{{{ dec_pack }}} + +{{> field8_funcs }} + +// Cooperative-inversion bucket accumulator ("coop-walker"). +// +// Structural re-architecture of ba_stream_walker. Each thread owns ONE +// contiguous task (its whole [thread_cut, next_thread_cut) range — read as +// cut 0 .. cut S of the per-thread task_cuts block, so this is a drop-in for +// the stream_walker bind group and indirect dispatch). Instead of each thread +// carrying S private slot accumulators and running its own S-wide batched +// inversion through a 16 KB var pref_scratch, the batched inversion +// is shared across the whole workgroup: every active thread contributes one +// dx per round, the workgroup computes all TPB inverses with a cooperative +// prefix/suffix product scan plus a SINGLE safegcd inversion, and each thread +// applies its affine add. +// +// Why: the walker is memory-bound / occupancy-limited; its occupancy is capped +// by per-thread register pressure (~150+ regs at S=8) and the 16 KB workgroup +// footprint (one resident workgroup on Mali). coop-walker drops per-thread +// state to a single accumulator (~20 regs) and workgroup memory to ~4 KB +// (two TPB-wide 256-bit scan arrays), so many more workgroups stay resident to +// hide memory latency — MsmV2-like occupancy at stream-walker memory. +// +// Output contract is identical to ba_stream_walker (so walker_partials_index + +// walker_combine + reduce are reused unchanged): +// - a bucket fully inside one thread's range -> bucket_sums[bucket_id] +// - a bucket split across a thread boundary -> partials at slot +// 2*(t*S+0)+{0,1} (split-start suffix / task-end prefix), summed by +// walker_combine. +// S is retained only for partial-slot layout compatibility with the shared +// partials buffer; coop-walker runs exactly ONE task per thread. +// +// params.x = NUM_THREADS, params.y = IDLE_ANCHOR, +// params.z = M_buckets, params.w = M_partials. + +const S: u32 = {{ s }}u; +const CUTS: u32 = S + 1u; +const TPB: u32 = {{ workgroup_size }}u; +// Inversion granularity: number of threads that share ONE batched inversion. +// G==TPB -> cooperative prefix/suffix scan (one inversion per workgroup). +// 1 per-group serial Montgomery batch inversion (TPB/G inversions, +// one per group leader, run concurrently across leaders). +// G==1 -> each thread inverts its own dx (no workgroup memory, no barriers). +const G: u32 = {{ g }}u; +const PG: u32 = 2u; +const L0_SIGN_BIT: u32 = 0x80000000u; +const L0_IDX_MASK: u32 = 0x7fffffffu; +const NO_BUCKET: u32 = 0xffffffffu; + +@group(0) @binding(0) var sorted_bucket_list: array; +@group(0) @binding(1) var sorted_count_list: array; +@group(0) @binding(2) var offsets: array; +@group(0) @binding(3) var task_cuts: array; +@group(0) @binding(4) var l0_index: array; +@group(0) @binding(5) var point_x: array>; +@group(0) @binding(6) var point_y: array>; +@group(0) @binding(7) var bucket_sums: array>; +@group(0) @binding(8) var partials_buf: array>; +@group(0) @binding(9) var partial_dest: array; +@group(0) @binding(10) var params: vec4; + +{{#coop_scan}} +// Two TPB-wide 256-bit scratch planes for the cooperative batch inversion: +// wpre becomes the inclusive prefix products, wsuf the inclusive suffix +// products. 2 vec4 per slot. ~4 KB total at TPB=64 (vs the walker's 16 KB). +var wpre: array, TPB * 2u>; +var wsuf: array, TPB * 2u>; +var w_inv_total: array, 2u>; +{{/coop_scan}} +{{#coop_group}} +// Per-group serial batch inversion: wdx holds each thread's dx (then is +// overwritten with its inv_dx); wpx holds the running prefix products the +// group leader needs for the backward pass. 2 vec4 per slot, ~4 KB total. +var wdx: array, TPB * 2u>; +var wpx: array, TPB * 2u>; +{{/coop_group}} +{{^coop_local}} +var w_any_active: atomic; +// Mirror of the activity flag read through workgroupUniformLoad so the loop +// break is a provably-uniform value (atomic loads are not, which would make +// the in-loop barriers fail Tint's uniformity analysis). +var w_active_flag: u32; +{{/coop_local}} + +fn load_pt_x(cursor: u32) -> array { + let packed = l0_index[cursor]; + let pt = packed & L0_IDX_MASK; + let q0 = point_x[2u * pt]; + let q1 = point_x[2u * pt + 1u]; + return array(q0.x, q0.y, q0.z, q0.w, q1.x, q1.y, q1.z, q1.w); +} + +fn load_pt_y(cursor: u32) -> array { + let packed = l0_index[cursor]; + let pt = packed & L0_IDX_MASK; + let q0 = point_y[2u * pt]; + let q1 = point_y[2u * pt + 1u]; + let y = array(q0.x, q0.y, q0.z, q0.w, q1.x, q1.y, q1.z, q1.w); + if ((packed & L0_SIGN_BIT) == 0u) { return y; } + let zero = array(0u, 0u, 0u, 0u, 0u, 0u, 0u, 0u); + return fr_sub_f8(zero, y); +} + +fn store_bucket_sum(bucket_id: u32, M: u32, x_val: array, y_val: array) { + let bx = PG * bucket_id; + bucket_sums[bx + 0u] = vec4(x_val[0], x_val[1], x_val[2], x_val[3]); + bucket_sums[bx + 1u] = vec4(x_val[4], x_val[5], x_val[6], x_val[7]); + let by = PG * M + PG * bucket_id; + bucket_sums[by + 0u] = vec4(y_val[0], y_val[1], y_val[2], y_val[3]); + bucket_sums[by + 1u] = vec4(y_val[4], y_val[5], y_val[6], y_val[7]); +} + +fn store_partial(pslot: u32, bucket_id: u32, M: u32, x_val: array, y_val: array) { + let bx = PG * pslot; + partials_buf[bx + 0u] = vec4(x_val[0], x_val[1], x_val[2], x_val[3]); + partials_buf[bx + 1u] = vec4(x_val[4], x_val[5], x_val[6], x_val[7]); + let by = PG * M + PG * pslot; + partials_buf[by + 0u] = vec4(y_val[0], y_val[1], y_val[2], y_val[3]); + partials_buf[by + 1u] = vec4(y_val[4], y_val[5], y_val[6], y_val[7]); + partial_dest[pslot] = bucket_id; +} + +{{#coop_scan}} +fn wstore(arr_pre: bool, l: u32, v: array) { + let a = vec4(v[0], v[1], v[2], v[3]); + let b = vec4(v[4], v[5], v[6], v[7]); + if (arr_pre) { + wpre[2u * l + 0u] = a; + wpre[2u * l + 1u] = b; + } else { + wsuf[2u * l + 0u] = a; + wsuf[2u * l + 1u] = b; + } +} + +fn wload_pre(l: u32) -> array { + let a = wpre[2u * l + 0u]; + let b = wpre[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} + +fn wload_suf(l: u32) -> array { + let a = wsuf[2u * l + 0u]; + let b = wsuf[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} +{{/coop_scan}} +{{#coop_group}} +fn wdx_store(l: u32, v: array) { + wdx[2u * l + 0u] = vec4(v[0], v[1], v[2], v[3]); + wdx[2u * l + 1u] = vec4(v[4], v[5], v[6], v[7]); +} +fn wdx_load(l: u32) -> array { + let a = wdx[2u * l + 0u]; + let b = wdx[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} +fn wpx_store(l: u32, v: array) { + wpx[2u * l + 0u] = vec4(v[0], v[1], v[2], v[3]); + wpx[2u * l + 1u] = vec4(v[4], v[5], v[6], v[7]); +} +fn wpx_load(l: u32) -> array { + let a = wpx[2u * l + 0u]; + let b = wpx[2u * l + 1u]; + return array(a.x, a.y, a.z, a.w, b.x, b.y, b.z, b.w); +} +{{/coop_group}} + +// Single field inversion in Montgomery form (unpack -> safegcd -> repack). +fn finv8(v: array) -> array { + var lin = unpack256_to_limbs(v); + var lout = {{ inv_fn }}(lin); + let p = pack_limbs_to_256(&lout); + return array(p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7]); +} + +fn coop_is_zero_f8(v: array) -> bool { + return (v[0] | v[1] | v[2] | v[3] | v[4] | v[5] | v[6] | v[7]) == 0u; +} + +@compute @workgroup_size({{ workgroup_size }}) +fn main(@builtin(global_invocation_id) gid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let t = gid.x; + let l = lid.x; + let NUM_THREADS = params.x; + let IDLE_ANCHOR = params.y; + let M_buckets = params.z; + let M_partials = params.w; + + // Per-thread scalar state (one task per thread; acc lives in registers). + var cursor: u32 = 0u; // l0_index point position + var bucket_end: u32 = 0u; // l0 position past current bucket + var task_end_sort: u32 = 0u; // sorted index of the task's last bucket + var task_end_cur: u32 = 0u; // l0 position past the task within that bucket + var cur_sorted: u32 = 0u; // index into sorted_bucket_list + var cur_bucket: u32 = 0u; // bucket id (for bucket_sums) + var is_first: u32 = 1u; + var slot_done: u32 = 1u; // default idle (covers t >= active range) + var split_start: u32 = 0u; // current bucket shared with a prior task + var acc_x: array; + var acc_y: array; + + // Slot-layout-compatible NO_BUCKET init: clear all S partial-slot pairs so + // the shared partials buffer is well-defined for walker_partials_index + // (matches ba_stream_walker's coverage of slots 2*(t*S+k)+{0,1}). + if (t < NUM_THREADS) { + for (var k: u32 = 0u; k < S; k = k + 1u) { + partial_dest[2u * (t * S + k) + 0u] = NO_BUCKET; + partial_dest[2u * (t * S + k) + 1u] = NO_BUCKET; + } + } + + // Initialise the single task from cut 0 (start) .. cut S (end). Mirrors + // ba_stream_walker's per-slot init for the whole thread range. + if (t < NUM_THREADS) { + let cut_base = t * CUTS * 2u; + let sb = task_cuts[cut_base + 0u]; + let so = task_cuts[cut_base + 1u]; + let eb = task_cuts[cut_base + S * 2u + 0u]; + let eo = task_cuts[cut_base + S * 2u + 1u]; + + let sb_id = sorted_bucket_list[sb]; + let sb_base = offsets[sb_id]; + let sb_count = sorted_count_list[sb]; + + var eff_sorted = sb; + var eff_id = sb_id; + var eff_base = sb_base; + var eff_count = sb_count; + var start_cursor: u32; + if (so == 0u) { + start_cursor = sb_base; + split_start = 0u; + } else if (so + 1u < sb_count) { + start_cursor = sb_base + so + 1u; + split_start = 1u; + } else { + eff_sorted = sb + 1u; + eff_id = sorted_bucket_list[eff_sorted]; + eff_base = offsets[eff_id]; + eff_count = sorted_count_list[eff_sorted]; + start_cursor = eff_base; + split_start = 0u; + } + + var te_sort: u32; + var te_cur: u32; + if (eo > 0u) { + te_sort = eb; + te_cur = offsets[sorted_bucket_list[eb]] + eo + 1u; + } else if (eb > 0u) { + te_sort = eb - 1u; + let pid = sorted_bucket_list[te_sort]; + te_cur = offsets[pid] + sorted_count_list[te_sort]; + } else { + te_sort = 0u; + te_cur = 0u; + } + + cursor = start_cursor; + bucket_end = eff_base + eff_count; + task_end_sort = te_sort; + task_end_cur = te_cur; + cur_sorted = eff_sorted; + cur_bucket = eff_id; + is_first = 1u; + slot_done = 0u; + + // Empty task (region-aware): start at or past the task end. + if (eff_sorted > te_sort || (eff_sorted == te_sort && start_cursor >= te_cur)) { + slot_done = 1u; + } else { + // Single-point leading segment: a split continuation landing on a + // bucket's last point can't go through the 2-point is_first step. + var seg_end = bucket_end; + if (eff_sorted == te_sort) { seg_end = te_cur; } + if (split_start == 1u && seg_end - start_cursor == 1u) { + let px = load_pt_x(start_cursor); + let py = load_pt_y(start_cursor); + if (eff_sorted == te_sort) { + store_partial(2u * (t * S + 0u) + 1u, eff_id, M_partials, px, py); + slot_done = 1u; + } else { + store_partial(2u * (t * S + 0u) + 0u, eff_id, M_partials, px, py); + let nxt = eff_sorted + 1u; + let nxt_id = sorted_bucket_list[nxt]; + let nxt_base = offsets[nxt_id]; + cur_sorted = nxt; + cur_bucket = nxt_id; + bucket_end = nxt_base + sorted_count_list[nxt]; + cursor = nxt_base; + split_start = 0u; + if (nxt > te_sort) { slot_done = 1u; } + } + } + } + } + + // Main loop. For the cooperative modes (scan / group) every iteration is + // uniform across the workgroup: the break is decided by a workgroup-shared + // activity flag, so all threads execute the same barriers regardless of + // when their own task ends. The local mode (G==1) has no in-loop barriers, + // so each thread simply runs until its own task is done. + loop { +{{^coop_local}} + workgroupBarrier(); + if (l == 0u) { atomicStore(&w_any_active, 0u); } + workgroupBarrier(); + if (slot_done == 0u) { atomicStore(&w_any_active, 1u); } + workgroupBarrier(); + if (l == 0u) { w_active_flag = atomicLoad(&w_any_active); } + // Uniform read (implicit barrier) so the break below is uniform. + let any_active = workgroupUniformLoad(&w_active_flag); + if (any_active == 0u) { break; } +{{/coop_local}} +{{#coop_local}} + if (slot_done == 1u) { break; } +{{/coop_local}} + + // Each active thread computes its pending dx; idle threads contribute + // Montgomery one (inert in the batch product). + var dx: array; + if (slot_done == 1u) { + dx = get_r_f8(); + } else { + var p_lx: array; + var p_rx: array; + if (is_first == 1u) { + p_lx = load_pt_x(cursor); + p_rx = load_pt_x(cursor + 1u); + } else { + p_lx = acc_x; + p_rx = load_pt_x(cursor); + } + dx = fr_sub_f8(p_rx, p_lx); + // Guard the batch product: a zero dx (equal x-coords, a measure-zero + // doubling case for distinct SRS points) would zero the whole group + // product. Substitute one so the failure stays isolated. + if (coop_is_zero_f8(dx)) { dx = get_r_f8(); } + } + +{{#coop_scan}} + wstore(true, l, dx); + wstore(false, l, dx); + + // Interleaved inclusive prefix- (wpre) and suffix- (wsuf) product + // scans (Hillis-Steele). Both share the same step schedule, so fusing + // them halves the workgroup-barrier count versus two sequential scans. + for (var off: u32 = 1u; off < TPB; off = off << 1u) { + workgroupBarrier(); + var ptmp: array; + var stmp: array; + let pact = l >= off; + let sact = l + off < TPB; + if (pact) { ptmp = montgomery_product_f8(wload_pre(l - off), wload_pre(l)); } + if (sact) { stmp = montgomery_product_f8(wload_suf(l), wload_suf(l + off)); } + workgroupBarrier(); + if (pact) { wstore(true, l, ptmp); } + if (sact) { wstore(false, l, stmp); } + } + workgroupBarrier(); + + // One inversion for the whole workgroup: invert the total product. + if (l == 0u) { + let total = wload_pre(TPB - 1u); + var total20 = unpack256_to_limbs(total); + var invtot20 = {{ inv_fn }}(total20); + let invtot = pack_limbs_to_256(&invtot20); + w_inv_total[0u] = vec4(invtot[0], invtot[1], invtot[2], invtot[3]); + w_inv_total[1u] = vec4(invtot[4], invtot[5], invtot[6], invtot[7]); + } + workgroupBarrier(); + let it0 = w_inv_total[0u]; + let it1 = w_inv_total[1u]; + let inv_total = array(it0.x, it0.y, it0.z, it0.w, it1.x, it1.y, it1.z, it1.w); + + // inv_dx_l = inv_total * (prod_{jl} dx_j). + var pre_excl: array; + if (l == 0u) { pre_excl = get_r_f8(); } else { pre_excl = wload_pre(l - 1u); } + var suf_excl: array; + if (l + 1u >= TPB) { suf_excl = get_r_f8(); } else { suf_excl = wload_suf(l + 1u); } + var inv_dx = montgomery_product_f8(inv_total, pre_excl); + inv_dx = montgomery_product_f8(inv_dx, suf_excl); +{{/coop_scan}} +{{#coop_group}} + // Per-group serial Montgomery batch inversion. Each group of G threads + // shares ONE safegcd inversion; the TPB/G group leaders run their + // inversions concurrently (one per leader). Only 2 barriers per round + // regardless of G, versus the scan's 2*log2(TPB). + wdx_store(l, dx); + workgroupBarrier(); + if ((l % G) == 0u) { + // Forward prefix products over the group's G dx values. + var run = wdx_load(l); + wpx_store(l, run); + for (var i: u32 = 1u; i < G; i = i + 1u) { + run = montgomery_product_f8(run, wdx_load(l + i)); + wpx_store(l + i, run); + } + // One inversion of the group product, then the backward pass: each + // inv_dx overwrites its dx slot in wdx. + var inv = finv8(run); + for (var i: u32 = G - 1u; i >= 1u; i = i - 1u) { + let invi = montgomery_product_f8(inv, wpx_load(l + i - 1u)); + inv = montgomery_product_f8(inv, wdx_load(l + i)); + wdx_store(l + i, invi); + } + wdx_store(l, inv); + } + workgroupBarrier(); + let inv_dx = wdx_load(l); +{{/coop_group}} +{{#coop_local}} + // Each thread inverts its own dx — no workgroup memory, no barriers. + let inv_dx = finv8(dx); +{{/coop_local}} + + if (slot_done == 0u) { + var p_lx: array; + var p_ly: array; + var p_rx: array; + var p_ry: array; + if (is_first == 1u) { + p_lx = load_pt_x(cursor); + p_ly = load_pt_y(cursor); + p_rx = load_pt_x(cursor + 1u); + p_ry = load_pt_y(cursor + 1u); + cursor = cursor + 2u; + } else { + p_lx = acc_x; + p_ly = acc_y; + p_rx = load_pt_x(cursor); + p_ry = load_pt_y(cursor); + cursor = cursor + 1u; + } + + var lambda = fr_sub_f8(p_ry, p_ly); + lambda = montgomery_product_f8(lambda, inv_dx); + var r_x = montgomery_product_f8(lambda, lambda); + let x_sum = fr_add_f8(p_lx, p_rx); + r_x = fr_sub_f8(r_x, x_sum); + var r_y = fr_sub_f8(p_lx, r_x); + r_y = montgomery_product_f8(lambda, r_y); + r_y = fr_sub_f8(r_y, p_ly); + + let task_done = (cur_sorted == task_end_sort) && (cursor >= task_end_cur); + let bucket_done = cursor >= bucket_end; + + if (task_done) { + let is_partial = (split_start == 1u) || (cursor < bucket_end); + if (is_partial) { + store_partial(2u * (t * S + 0u) + 1u, cur_bucket, M_partials, r_x, r_y); + } else { + store_bucket_sum(cur_bucket, M_buckets, r_x, r_y); + } + slot_done = 1u; + } else if (bucket_done) { + if (split_start == 1u) { + store_partial(2u * (t * S + 0u) + 0u, cur_bucket, M_partials, r_x, r_y); + } else { + store_bucket_sum(cur_bucket, M_buckets, r_x, r_y); + } + let nxt = cur_sorted + 1u; + let nxt_id = sorted_bucket_list[nxt]; + let nxt_base = offsets[nxt_id]; + cur_sorted = nxt; + cur_bucket = nxt_id; + bucket_end = nxt_base + sorted_count_list[nxt]; + cursor = nxt_base; + is_first = 1u; + split_start = 0u; + } else { + acc_x = r_x; + acc_y = r_y; + is_first = 0u; + } + } + } + + {{{ recompile }}} +} diff --git a/barretenberg/ts/test-msm-xcheck.mjs b/barretenberg/ts/test-msm-xcheck.mjs new file mode 100644 index 000000000000..8368cfd1e8a1 --- /dev/null +++ b/barretenberg/ts/test-msm-xcheck.mjs @@ -0,0 +1,60 @@ +// Headless WebGPU MSM cross-check vs @noble/curves on SwiftShader (Vulkan ICD). +// Usage: node test-msm-xcheck.mjs [logn] [seed] [extraQuery] +// logn default 10 +// seed default 1 (forwarded as &seed=) +// extraQuery e.g. "accum=coop" forwarded verbatim +import { chromium } from 'playwright-core'; + +const logn = process.argv[2] || '10'; +const seed = process.argv[3] || '1'; +const extra = process.argv[4] || ''; + +const icdDir = '/opt/ms-playwright/chromium-1148/chrome-linux'; +const browser = await chromium.launch({ + headless: true, + executablePath: process.env.CHROMIUM_PATH || `${icdDir}/chrome`, + env: { + ...process.env, + VK_ICD_FILENAMES: `${icdDir}/vk_swiftshader_icd.json`, + VK_DRIVER_FILES: `${icdDir}/vk_swiftshader_icd.json`, + }, + args: [ + '--enable-unsafe-webgpu', + '--enable-features=Vulkan', + '--use-vulkan=swiftshader', + '--use-webgpu-adapter=swiftshader', + '--disable-vulkan-surface', + '--disable-gpu-sandbox', + '--no-sandbox', + '--disable-http2', + '--ignore-certificate-errors', + ], +}); +const page = await browser.newPage(); +const lines = []; +page.on('console', m => { const t = m.text(); lines.push(t); if (process.env.VERBOSE) console.log(` . ${t}`); }); +page.on('pageerror', e => { lines.push(`PAGEERR ${e.message}`); console.log(` ! ${e.message}`); }); + +const mode = process.env.AUTORUN || 'msm-noble'; +let q = `coi=1&autorun=${mode}&logn=${logn}&scalar_seed=${seed}`; +if (process.env.REPS) q += `&reps=${process.env.REPS}`; +if (extra) q += `&${extra}`; +console.log(`MSM ${mode} logn=${logn} seed=${seed} ${extra} on SwiftShader...`); +let runnerErr = null; +try { + await page.goto(`http://localhost:5173/dev/msm-webgpu/index.html?${q}`, { waitUntil: 'load', timeout: 120000 }); + await page.waitForFunction( + () => /\[autorun\] state=/.test(document.getElementById('log')?.textContent ?? ''), + null, { timeout: 600000 }); +} catch (e) { runnerErr = e.message; } + +const logText = await page.evaluate(() => { + const el = document.getElementById('log'); + return el ? Array.from(el.children).map(c => c.textContent ?? '').join('\n') : ''; +}); +console.log('─'.repeat(64)); +if (runnerErr) console.log(`runner: ${runnerErr}`); +for (const l of logText.split('\n').slice(-60)) console.log(l); +await browser.close(); +const ok = /state=done/.test(logText) && !/state=error/.test(logText); +process.exit(ok ? 0 : 1);