diff --git a/prototype/test_correctness_suite.py b/prototype/test_correctness_suite.py index c77da7c..e19125b 100644 --- a/prototype/test_correctness_suite.py +++ b/prototype/test_correctness_suite.py @@ -12,6 +12,7 @@ """ import sys +import time import numpy as np import pytest from prototype.model_tools import ToyModel @@ -319,19 +320,28 @@ def test_throughput_consistency(self): for bs in batch_sizes: x = np.random.randn(bs, 8).astype(np.float32) - # Measure time for single-node - start = np.datetime64('now') - _ = single_node_forward(model, x[0]) # First element only - end = np.datetime64('now') + # Warm up to avoid first-run JIT/cache effects + for _ in range(3): + _ = single_node_forward(model, x[0]) - latency = (end - start).astype(np.float64) * 1e9 # nanoseconds + # Take best-of-N to reduce OS scheduling noise + samples = [] + for _ in range(10): + start = time.perf_counter_ns() + _ = single_node_forward(model, x[0]) # First element only + end = time.perf_counter_ns() + samples.append(end - start) - latencies.append(latency) + best_latency = float(min(samples)) # nanoseconds + latencies.append(best_latency) # Latencies should be within 20% of each other (allowing for variance) min_latency = min(latencies) max_latency = max(latencies) + if min_latency < 100: + pytest.skip("Timer resolution too low to measure latency reliably") + assert (max_latency - min_latency) / min_latency < 0.2, \ f"Latency variance too high: {min_latency} vs {max_latency}"