From 208fd03be27af85fd5063bc3b9863ac6bec6357c Mon Sep 17 00:00:00 2001 From: Natalia Date: Sun, 22 Feb 2026 02:16:22 -0800 Subject: [PATCH 1/3] Fix batch-and-skip benchmark exploit via per-call timing and correctness checks The current eval times all 15 custom_kernel() calls as a single batch and divides by 15. A malicious submission can exploit this by deferring all work to one call (batching 15 problems into a single kernel launch) and making the other 14 calls no-ops, reporting ~1/15th of the real per-call cost. Cloning data alone (as proposed in #102) does not fully prevent this -- a shape-matching fallback path can still collect new data objects and batch them. This fix: - Clones data each timing iteration (prevents object-identity caching) - Times each call individually with its own CUDA events and GPU sync (prevents amortization across calls) - Checks correctness after each individual call in recheck/leaderboard mode (catches deferred-computation exploits that return uncomputed tensors) - Uses a local seed variable instead of mutating test.args - Fixes the recheck indentation bug where only the last call was checked --- .../nvidia/eval_better_bench_grouped_gemm.py | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/problems/nvidia/eval_better_bench_grouped_gemm.py b/problems/nvidia/eval_better_bench_grouped_gemm.py index 09b5279..d41a2f4 100644 --- a/problems/nvidia/eval_better_bench_grouped_gemm.py +++ b/problems/nvidia/eval_better_bench_grouped_gemm.py @@ -240,12 +240,16 @@ def _run_single_benchmark( durations = [] data_list = [] - # generate input data once + # generate input data once (local seed avoids mutating test.args) + local_seed = test.args.get("seed", None) for i in range(NUM_ITERATIONS_PER_BENCHMARK): - if "seed" in test.args: - test.args["seed"] += 42 - data = generate_input(**test.args) + if local_seed is not None: + local_seed += 42 + args = {**test.args, "seed": local_seed} + else: + args = test.args + data = generate_input(**args) data_list.append(data) check_copy = _clone_data(data_list) @@ -263,35 +267,40 @@ def _run_single_benchmark( if not good: return message - # now, do multiple timing runs without further correctness testing - # there is an upper bound of 200 runs, and a lower bound of 3 runs; - # otherwise, we repeat until we either measure at least 10 full seconds, - # or the relative error of the mean is below 1%. + # Timing: individual per-call measurement with GPU sync between calls. + # This prevents "batch-and-skip" exploits where a submission defers all + # work to one call and returns cached/uncomputed results for the rest. + # Data is cloned each iteration to prevent object-identity caching. bm_start_time = time.perf_counter_ns() for i in range(max_repeats): + iteration_data = _clone_data(data_list) torch.cuda.synchronize() + clear_l2_cache() + per_call_durations = [] outputs = [] - clear_l2_cache() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - for data in data_list: + for j, data in enumerate(iteration_data): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + per_call_durations.append( + start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + ) outputs.append(output) - end_event.record() - torch.cuda.synchronize() - duration = ( - start_event.elapsed_time(end_event) / NUM_ITERATIONS_PER_BENCHMARK - ) * 1e6 # Convert ms to ns - if recheck: - for reference_output, custom_output in zip(check_copy, outputs): - good, message = check_implementation(reference_output, custom_output) - if not good: - return message + # Per-call correctness check catches deferred-computation exploits: + # if a submission skips the kernel and returns uncomputed tensors, + # the check fails immediately. + if recheck: + good, message = check_implementation(check_copy[j], output) + if not good: + return message + duration = sum(per_call_durations) / NUM_ITERATIONS_PER_BENCHMARK durations.append(duration) total_bm_duration = time.perf_counter_ns() - bm_start_time From 3d389a2561ff72d55848df8e60155802e7f4fa40 Mon Sep 17 00:00:00 2001 From: Luc Chartier Date: Sun, 22 Feb 2026 07:10:08 -0500 Subject: [PATCH 2/3] Improve grouped GEMM eval anti-cheat checks --- .../nvidia/eval_better_bench_grouped_gemm.py | 111 +++++++++++++----- 1 file changed, 83 insertions(+), 28 deletions(-) diff --git a/problems/nvidia/eval_better_bench_grouped_gemm.py b/problems/nvidia/eval_better_bench_grouped_gemm.py index d41a2f4..c1a394b 100644 --- a/problems/nvidia/eval_better_bench_grouped_gemm.py +++ b/problems/nvidia/eval_better_bench_grouped_gemm.py @@ -1,11 +1,13 @@ import base64 import dataclasses import multiprocessing +import random import re import time import os import sys import math +import random # Disable CuTe DSL file caching for more stable benchmarking os.environ["CUTE_DSL_DISABLE_FILE_CACHING"] = "1" @@ -240,7 +242,7 @@ def _run_single_benchmark( durations = [] data_list = [] - # generate input data once (local seed avoids mutating test.args) + # generate input data once local_seed = test.args.get("seed", None) for i in range(NUM_ITERATIONS_PER_BENCHMARK): @@ -253,8 +255,14 @@ def _run_single_benchmark( data_list.append(data) check_copy = _clone_data(data_list) - - # first, one obligatory correctness check + # Deterministic but hidden probe stream. + # In benchmark mode we use randomized call windows and sparse probes. + # In leaderboard mode we do one full sweep up front, then lightweight probes. + probe_seed = _combine(int(test.args.get("seed", 0) or 0), 0x4D455452) + probe_rng = random.Random(probe_seed) + full_calls = len(data_list) + + # First, one obligatory correctness check on fresh clones. outputs = [] try: for data in data_list: @@ -267,45 +275,88 @@ def _run_single_benchmark( if not good: return message - # Timing: individual per-call measurement with GPU sync between calls. - # This prevents "batch-and-skip" exploits where a submission defers all - # work to one call and returns cached/uncomputed results for the rest. + # Timing: per-call intervals captured with CUDA events and one sync. + # We randomize window length/order in benchmark mode to break fixed-N exploits. # Data is cloned each iteration to prevent object-identity caching. bm_start_time = time.perf_counter_ns() for i in range(max_repeats): + # Clone and shuffle data before timing to prevent both + # object-identity caching and call-order caching exploits iteration_data = _clone_data(data_list) + shuffle_order = list(range(len(iteration_data))) + random.shuffle(shuffle_order) + iteration_data = [iteration_data[j] for j in shuffle_order] + torch.cuda.synchronize() - clear_l2_cache() - per_call_durations = [] + if recheck: + call_indices = list(range(full_calls)) + else: + call_indices = list(range(full_calls)) + probe_rng.shuffle(call_indices) + min_calls = max(4, full_calls - 6) + n_calls = probe_rng.randint(min_calls, full_calls) + call_indices = call_indices[:n_calls] + outputs = [] - for j, data in enumerate(iteration_data): - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - start_event.record() - output = custom_kernel(data) - end_event.record() - torch.cuda.synchronize() - per_call_durations.append( - start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns - ) - outputs.append(output) + events = [torch.cuda.Event(enable_timing=True) for _ in range(len(call_indices) + 1)] + if recheck: + integrity_repeat = (i == 0) or (i % 20 == 0) + else: + integrity_repeat = (i < 3) or (i % 25 == 0) - # Per-call correctness check catches deferred-computation exploits: - # if a submission skips the kernel and returns uncomputed tensors, - # the check fails immediately. - if recheck: - good, message = check_implementation(check_copy[j], output) + if integrity_repeat and len(call_indices) <= 1: + in_loop_probe_pos = 0 if call_indices else None + elif integrity_repeat: + # Probe before last call to expose deferred-until-last behavior. + in_loop_probe_pos = probe_rng.randrange(0, len(call_indices) - 1) + else: + in_loop_probe_pos = None + + events[0].record() + for k, idx in enumerate(call_indices): + output = custom_kernel(iteration_data[idx]) + outputs.append((idx, output)) + events[k + 1].record() + + # In-loop probe check catches deferred-until-last exploits that would + # otherwise pass if outputs are only validated after the final call. + if in_loop_probe_pos is not None and k == in_loop_probe_pos: + torch.cuda.synchronize() + good, message = check_implementation(check_copy[idx], output) if not good: return message + torch.cuda.synchronize() + + per_call_durations = [ + events[k].elapsed_time(events[k + 1]) * 1e6 for k in range(len(call_indices)) + ] - duration = sum(per_call_durations) / NUM_ITERATIONS_PER_BENCHMARK - durations.append(duration) + # Correctness policy: + # - benchmark: sparse hidden integrity repeats + randomized windows/order. + # - leaderboard: sparse integrity repeats; first repeat gets full sweep. + if recheck: + if i == 0: + check_positions = list(range(len(outputs))) + else: + check_positions = [] + else: + check_positions = [] + + for pos in check_positions: + idx, output = outputs[pos] + good, message = check_implementation(check_copy[idx], output) + if not good: + return message + + duration = sum(per_call_durations) / len(call_indices) + if not integrity_repeat: + durations.append(duration) total_bm_duration = time.perf_counter_ns() - bm_start_time if ( - i > 1 and total_bm_duration > 1e8 + len(durations) > 1 and total_bm_duration > 1e8 ): # at least 2 runs, and at least 100 ms total time stats = calculate_stats(durations) # stop if either @@ -319,6 +370,9 @@ def _run_single_benchmark( ): break + if not durations: + return "benchmark produced no timing samples" + return calculate_stats(durations) @@ -527,8 +581,9 @@ def main(): break logger.log("check", "pass" if passed else "fail") + return 0 if passed else 112 elif mode == "profile": - run_profiling(logger, pool, tests) + return run_profiling(logger, pool, tests) else: # TODO: Implement script mode return 2 From 340e48ebeac4d7f1fcb011b5b2adb557b941e920 Mon Sep 17 00:00:00 2001 From: Luc Chartier Date: Sun, 22 Feb 2026 07:45:34 -0500 Subject: [PATCH 3/3] Add fingerprint audit for deferred output mutation --- .../nvidia/eval_better_bench_grouped_gemm.py | 134 ++++++++++++++++-- 1 file changed, 122 insertions(+), 12 deletions(-) diff --git a/problems/nvidia/eval_better_bench_grouped_gemm.py b/problems/nvidia/eval_better_bench_grouped_gemm.py index c1a394b..92424a0 100644 --- a/problems/nvidia/eval_better_bench_grouped_gemm.py +++ b/problems/nvidia/eval_better_bench_grouped_gemm.py @@ -175,6 +175,88 @@ def _clone_data(data): return data +def _collect_output_tensors(output): + """Collect tensors from nested output structure in deterministic order.""" + tensors = [] + + def _walk(x): + if isinstance(x, torch.Tensor): + tensors.append(x) + elif isinstance(x, (list, tuple)): + for y in x: + _walk(y) + elif isinstance(x, dict): + for k in sorted(x.keys()): + _walk(x[k]) + + _walk(output) + return tensors + + +def _make_fingerprint_plan(output, gen, samples_per_tensor: int = 256): + """ + Build a secret sampled hash plan for this output structure. + """ + tensors = _collect_output_tensors(output) + if not tensors: + return [] + + plan = [] + for t in tensors: + n = int(t.numel()) + s = min(samples_per_tensor, n) + if s <= 0: + plan.append((0, None, None, None)) + continue + idx = torch.randint(0, n, (s,), generator=gen, device=t.device, dtype=torch.int64) + w1 = torch.randint( + -(1 << 20), (1 << 20), (s,), generator=gen, device=t.device, dtype=torch.int32 + ).to(torch.float64) + w2 = torch.randint( + -(1 << 20), (1 << 20), (s,), generator=gen, device=t.device, dtype=torch.int32 + ).to(torch.float64) + plan.append((n, idx, w1, w2)) + return plan + + +def _fingerprint_output(output, plan): + """ + Compute a lightweight sampled fingerprint of output tensor contents. + + Returns two device scalars (h1, h2). If output changed post-return, the + fingerprint almost certainly changes too. + """ + tensors = _collect_output_tensors(output) + if len(tensors) != len(plan): + raise ValueError( + f"output structure changed: expected {len(plan)} tensors, got {len(tensors)}" + ) + + if not tensors: + z = torch.zeros((), dtype=torch.float64) + return z, z + + device = tensors[0].device + h1 = torch.zeros((), device=device, dtype=torch.float64) + h2 = torch.zeros((), device=device, dtype=torch.float64) + + for t, (expected_n, idx, w1, w2) in zip(tensors, plan): + n = int(t.numel()) + if n != expected_n: + raise ValueError(f"output tensor size changed: expected {expected_n}, got {n}") + if expected_n == 0: + continue + vals = t.reshape(-1).index_select(0, idx).to(torch.float64) + vals = torch.nan_to_num(vals, nan=0.0, posinf=1e6, neginf=-1e6) + h1 = h1 + (vals * w1).sum(dtype=torch.float64) + h2 = h2 + (vals * w2).sum(dtype=torch.float64) + return h1, h2 + + +def _fingerprint_equal(a, b) -> bool: + return torch.equal(a[0], b[0]) and torch.equal(a[1], b[1]) + + def _run_single_test(test: TestCase): """ Runs a single test case. Do not call directly @@ -261,6 +343,9 @@ def _run_single_benchmark( probe_seed = _combine(int(test.args.get("seed", 0) or 0), 0x4D455452) probe_rng = random.Random(probe_seed) full_calls = len(data_list) + fp_gen = torch.Generator(device="cuda") + fp_seed = _combine(probe_seed, 0xF1A9E5) & ((1 << 63) - 1) + fp_gen.manual_seed(fp_seed) # First, one obligatory correctness check on fresh clones. outputs = [] @@ -274,6 +359,10 @@ def _run_single_benchmark( good, message = check_implementation(reference_output, custom_output) if not good: return message + try: + fingerprint_plans = [_make_fingerprint_plan(out, fp_gen) for out in outputs] + except Exception as E: + return f"fingerprint plan build failed: {E}" # Timing: per-call intervals captured with CUDA events and one sync. # We randomize window length/order in benchmark mode to break fixed-N exploits. @@ -290,21 +379,27 @@ def _run_single_benchmark( torch.cuda.synchronize() + if recheck: + integrity_repeat = (i == 0) or (i % 20 == 0) + else: + integrity_repeat = (i < 3) or (i % 25 == 0) + if recheck: call_indices = list(range(full_calls)) else: call_indices = list(range(full_calls)) probe_rng.shuffle(call_indices) - min_calls = max(4, full_calls - 6) - n_calls = probe_rng.randint(min_calls, full_calls) + if integrity_repeat: + # Integrity repeats must exercise the full call window so + # flush-at-N exploits cannot hide behind short random windows. + n_calls = full_calls + else: + min_calls = max(4, full_calls - 6) + n_calls = probe_rng.randint(min_calls, full_calls) call_indices = call_indices[:n_calls] outputs = [] events = [torch.cuda.Event(enable_timing=True) for _ in range(len(call_indices) + 1)] - if recheck: - integrity_repeat = (i == 0) or (i % 20 == 0) - else: - integrity_repeat = (i < 3) or (i % 25 == 0) if integrity_repeat and len(call_indices) <= 1: in_loop_probe_pos = 0 if call_indices else None @@ -314,21 +409,36 @@ def _run_single_benchmark( else: in_loop_probe_pos = None + probe_snapshot = None events[0].record() for k, idx in enumerate(call_indices): output = custom_kernel(iteration_data[idx]) outputs.append((idx, output)) events[k + 1].record() - # In-loop probe check catches deferred-until-last exploits that would - # otherwise pass if outputs are only validated after the final call. + # Snapshot output state immediately after return; compare again after + # the full window to detect post-return deferred writes. if in_loop_probe_pos is not None and k == in_loop_probe_pos: - torch.cuda.synchronize() - good, message = check_implementation(check_copy[idx], output) - if not good: - return message + try: + fp_before = _fingerprint_output(output, fingerprint_plans[idx]) + except Exception as E: + return f"fingerprint snapshot failed: {E}" + probe_snapshot = (idx, output, fp_before) torch.cuda.synchronize() + if probe_snapshot is not None: + idx, probe_output, fp_before = probe_snapshot + try: + fp_after = _fingerprint_output(probe_output, fingerprint_plans[idx]) + except Exception as E: + return f"fingerprint verify failed: {E}" + torch.cuda.synchronize() + if not _fingerprint_equal(fp_before, fp_after): + return ( + "detected deferred/cross-call output mutation " + f"(call_index={idx}, window_calls={len(call_indices)})" + ) + per_call_durations = [ events[k].elapsed_time(events[k + 1]) * 1e6 for k in range(len(call_indices)) ]