Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions prototype/test_correctness_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""

import sys
import time
import numpy as np
import pytest
from prototype.model_tools import ToyModel
Expand Down Expand Up @@ -319,19 +320,28 @@ def test_throughput_consistency(self):
for bs in batch_sizes:
x = np.random.randn(bs, 8).astype(np.float32)

# Measure time for single-node
start = np.datetime64('now')
_ = single_node_forward(model, x[0]) # First element only
end = np.datetime64('now')
# Warm up to avoid first-run JIT/cache effects
for _ in range(3):
_ = single_node_forward(model, x[0])

latency = (end - start).astype(np.float64) * 1e9 # nanoseconds
# Take best-of-N to reduce OS scheduling noise
samples = []
for _ in range(10):
start = time.perf_counter_ns()
_ = single_node_forward(model, x[0]) # First element only
end = time.perf_counter_ns()
samples.append(end - start)

latencies.append(latency)
best_latency = float(min(samples)) # nanoseconds
latencies.append(best_latency)

# Latencies should be within 20% of each other (allowing for variance)
min_latency = min(latencies)
max_latency = max(latencies)

if min_latency < 100:
pytest.skip("Timer resolution too low to measure latency reliably")

assert (max_latency - min_latency) / min_latency < 0.2, \
f"Latency variance too high: {min_latency} vs {max_latency}"

Expand Down
Loading