Skip to content

Commit 43343d1

Browse files
authored
Add files via upload
1 parent e3b7a31 commit 43343d1

7 files changed

Lines changed: 253 additions & 54 deletions

File tree

chronos/mlx/__init__.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,6 @@
1616
- torch.compile unavailable; mx.compile() used instead
1717
- All modules are mlx.nn.Module, not torch.nn.Module
1818
"""
19-
from chronos.mlx.model import ChronosMLXModel
20-
from chronos.mlx.moe import ChronosMLXMOE
21-
from chronos.mlx.attention import MLAAttentionMLX, SlidingWindowAttentionMLX
22-
from chronos.mlx.expert_store import MLXExpertStore
23-
from chronos.mlx.inference import ChronosMLXInferenceEngine
24-
2519
__all__ = [
2620
"ChronosMLXModel",
2721
"ChronosMLXMOE",
@@ -30,3 +24,31 @@
3024
"MLXExpertStore",
3125
"ChronosMLXInferenceEngine",
3226
]
27+
28+
29+
def __getattr__(name):
30+
if name == "ChronosMLXModel":
31+
from chronos.mlx.model import ChronosMLXModel
32+
33+
return ChronosMLXModel
34+
if name == "ChronosMLXMOE":
35+
from chronos.mlx.moe import ChronosMLXMOE
36+
37+
return ChronosMLXMOE
38+
if name == "MLAAttentionMLX":
39+
from chronos.mlx.attention import MLAAttentionMLX
40+
41+
return MLAAttentionMLX
42+
if name == "SlidingWindowAttentionMLX":
43+
from chronos.mlx.attention import SlidingWindowAttentionMLX
44+
45+
return SlidingWindowAttentionMLX
46+
if name == "MLXExpertStore":
47+
from chronos.mlx.expert_store import MLXExpertStore
48+
49+
return MLXExpertStore
50+
if name == "ChronosMLXInferenceEngine":
51+
from chronos.mlx.inference import ChronosMLXInferenceEngine
52+
53+
return ChronosMLXInferenceEngine
54+
raise AttributeError(name)

chronos/mlx/inference.py

Lines changed: 86 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,14 @@ def __init__(self, model, config, ssd_dir: str = "./expert_cache_mlx"):
5050
self._runtime_stats = {
5151
"resident_hits": 0,
5252
"resident_misses": 0,
53+
"resident_vram_hits": 0,
54+
"resident_ram_hits": 0,
55+
"selection_hits": 0,
56+
"selection_misses": 0,
5357
"prediction_hits": 0,
5458
"prediction_total": 0,
59+
"prefetch_queue_drops": 0,
60+
"prefetch_wait_time_s": 0.0,
5561
"sync_ssd_loads": 0,
5662
"on_demand_load_time_s": 0.0,
5763
}
@@ -75,6 +81,11 @@ def loader(eid: int) -> bool:
7581
eid in self.store._warm
7682
and self.store._layer_states_complete(self.store._warm.get(eid))
7783
)
84+
with self._stats_lock:
85+
if was_hot or was_warm:
86+
self._runtime_stats["selection_hits"] += 1
87+
else:
88+
self._runtime_stats["selection_misses"] += 1
7889
t0 = time.monotonic()
7990
if not was_hot and not was_warm:
8091
with self._stats_lock:
@@ -85,18 +96,29 @@ def loader(eid: int) -> bool:
8596
self._runtime_stats["on_demand_load_time_s"] += elapsed
8697
if ok:
8798
self._runtime_stats["resident_hits" if (was_hot or was_warm) else "resident_misses"] += 1
99+
if was_hot:
100+
self._runtime_stats["resident_vram_hits"] += 1
101+
elif was_warm:
102+
self._runtime_stats["resident_ram_hits"] += 1
88103
else:
89104
self._runtime_stats["resident_misses"] += 1
90105
self._runtime_stats["prediction_total"] += 1
91-
if eid in self._last_predicted:
106+
if (was_hot or was_warm) and eid in self._last_predicted:
92107
self._runtime_stats["prediction_hits"] += 1
93108
return ok
94109

95110
def touch(eid: int) -> None:
111+
eid = int(eid)
96112
with self.store._lock:
97113
if eid in self.store._hot_lru:
98-
if self.store._expert_live_all_layers(int(eid)):
114+
if self.store._expert_live_all_layers(eid):
99115
self.store._hot_lru.move_to_end(eid)
116+
with self._stats_lock:
117+
self._runtime_stats["selection_hits"] += 1
118+
self._runtime_stats["resident_vram_hits"] += 1
119+
self._runtime_stats["prediction_total"] += 1
120+
if eid in self._last_predicted:
121+
self._runtime_stats["prediction_hits"] += 1
100122
else:
101123
self.store._hot_lru.pop(eid, None)
102124

@@ -119,10 +141,44 @@ def _prefetch_loop(self):
119141
self._prefetch_q.task_done()
120142

121143
def _schedule_prefetch(self, expert_ids: List[int]):
144+
if not expert_ids:
145+
return
122146
try:
123147
self._prefetch_q.put_nowait(expert_ids)
124148
except queue.Full:
125-
pass
149+
with self._stats_lock:
150+
self._runtime_stats["prefetch_queue_drops"] += 1
151+
152+
def _prefetch_and_promote_window(self, expert_ids: List[int], timeout_s: float = 0.012) -> None:
153+
if not expert_ids or self.store.storage_format == "full_dram":
154+
return
155+
pending = []
156+
for eid in dict.fromkeys(int(eid) for eid in expert_ids):
157+
with self.store._lock:
158+
if eid in self.store._hot_lru:
159+
continue
160+
if eid in self.store._warm and self.store._layer_states_complete(self.store._warm.get(eid)):
161+
pending.append(eid)
162+
continue
163+
pending.append(eid)
164+
if not pending:
165+
return
166+
self._schedule_prefetch(pending)
167+
deadline = time.monotonic() + max(0.0, float(timeout_s))
168+
while time.monotonic() < deadline:
169+
self._promote_ready(pending)
170+
with self.store._lock:
171+
ready = all(
172+
eid in self.store._hot_lru
173+
or (eid in self.store._warm and self.store._layer_states_complete(self.store._warm.get(eid)))
174+
for eid in pending
175+
)
176+
if ready:
177+
break
178+
time.sleep(0.001)
179+
self._promote_ready(pending)
180+
with self._stats_lock:
181+
self._runtime_stats["prefetch_wait_time_s"] += max(0.0, time.monotonic() - (deadline - max(0.0, float(timeout_s))))
126182

127183
def stop(self):
128184
self._stop.set()
@@ -157,8 +213,14 @@ def generate(
157213
self._runtime_stats = {
158214
"resident_hits": 0,
159215
"resident_misses": 0,
216+
"resident_vram_hits": 0,
217+
"resident_ram_hits": 0,
218+
"selection_hits": 0,
219+
"selection_misses": 0,
160220
"prediction_hits": 0,
161221
"prediction_total": 0,
222+
"prefetch_queue_drops": 0,
223+
"prefetch_wait_time_s": 0.0,
162224
"sync_ssd_loads": 0,
163225
"on_demand_load_time_s": 0.0,
164226
}
@@ -186,6 +248,10 @@ def generate(
186248
next_token = self._sample(logits[:, -1, :], temperature, top_p)
187249
activated_ids: List[int] = []
188250
tokens = 1
251+
if scheduler is None and lookahead_probs is not None:
252+
future_ids = self._predict_future_experts(lookahead_probs)
253+
self._last_predicted = set(int(eid) for eid in future_ids)
254+
self._prefetch_and_promote_window(future_ids, timeout_s=0.025)
189255
yield int(next_token.item())
190256

191257
for _ in range(max_new_tokens - 1):
@@ -196,8 +262,7 @@ def generate(
196262
if lookahead_probs is not None:
197263
future_ids = self._predict_future_experts(lookahead_probs)
198264
self._last_predicted = set(int(eid) for eid in future_ids)
199-
self._schedule_prefetch(future_ids)
200-
self._promote_ready(future_ids)
265+
self._prefetch_and_promote_window(future_ids)
201266
avail_masks = self._build_avail_masks(next_token)
202267

203268
token_in = next_token.reshape(1, 1)
@@ -272,16 +337,22 @@ def _promote_ready(self, expert_ids: List[int]) -> None:
272337
def _runtime_stat_fields(self) -> dict:
273338
with self._stats_lock:
274339
stats = dict(self._runtime_stats)
275-
total = int(stats.get("resident_hits", 0)) + int(stats.get("resident_misses", 0))
340+
total = int(stats.get("selection_hits", 0)) + int(stats.get("selection_misses", 0))
276341
pred_total = int(stats.get("prediction_total", 0))
277342
return {
278-
"resident_hit_rate": round(float(stats.get("resident_hits", 0)) / max(total, 1), 4),
279-
"cache_hit_rate": round(float(stats.get("resident_hits", 0)) / max(total, 1), 4),
280-
"cache_hits": int(stats.get("resident_hits", 0)),
281-
"cache_misses": int(stats.get("resident_misses", 0)),
343+
"resident_hit_rate": round(float(stats.get("selection_hits", 0)) / max(total, 1), 4),
344+
"cache_hit_rate": round(float(stats.get("selection_hits", 0)) / max(total, 1), 4),
345+
"cache_hits": int(stats.get("selection_hits", 0)),
346+
"cache_misses": int(stats.get("selection_misses", 0)),
347+
"expert_selection_hits": int(stats.get("selection_hits", 0)),
348+
"expert_selection_misses": int(stats.get("selection_misses", 0)),
282349
"prediction_hit_rate": round(float(stats.get("prediction_hits", 0)) / max(pred_total, 1), 4),
283350
"prediction_hits": int(stats.get("prediction_hits", 0)),
284351
"prediction_total": pred_total,
352+
"resident_vram_hits": int(stats.get("resident_vram_hits", 0)),
353+
"resident_ram_hits": int(stats.get("resident_ram_hits", 0)),
354+
"prefetch_queue_drops": int(stats.get("prefetch_queue_drops", 0)),
355+
"prefetch_wait_time_s": round(float(stats.get("prefetch_wait_time_s", 0.0)), 4),
285356
"sync_ssd_loads": int(stats.get("sync_ssd_loads", 0)),
286357
"on_demand_loads": int(stats.get("resident_misses", 0)),
287358
"on_demand_load_time_s": round(float(stats.get("on_demand_load_time_s", 0.0)), 4),
@@ -309,16 +380,18 @@ def _sample(logits: mx.array, temperature: float, top_p: float) -> mx.array:
309380
@staticmethod
310381
def _memory_snapshot() -> dict:
311382
try:
312-
from chronos.backend.mac_diagnostics import mlx_memory_snapshot
383+
from chronos.backend.mac_diagnostics import mlx_memory_snapshot, rss_snapshot
313384

314-
return mlx_memory_snapshot()
385+
out = dict(mlx_memory_snapshot())
386+
out.update(rss_snapshot())
387+
return out
315388
except Exception:
316389
return {}
317390

318391
@staticmethod
319392
def _memory_fields(snapshot: dict, suffix: str) -> dict:
320393
out = {}
321-
for key in ("mlx_active_gb", "mlx_cache_gb", "mlx_peak_gb"):
394+
for key in ("rss_gb", "mlx_active_gb", "mlx_cache_gb", "mlx_peak_gb"):
322395
if key in snapshot:
323396
prefix = key[:-3] if key.endswith("_gb") else key
324397
out[f"{prefix}_{suffix}_gb"] = snapshot[key]

chronos/mlx/moe.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,8 @@ def __call__(
120120

121121
if python_avail is not None:
122122
selected = bool(mx.any(tok_mask).item())
123+
if not selected:
124+
continue
123125
is_avail = i in python_avail
124126
if selected and not is_avail:
125127
loader = getattr(self, "runtime_on_demand_loader", None)

chronos/mlx/training/trainer.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@
55
import time
66
from dataclasses import dataclass
77

8-
import mlx.core as mx
9-
import mlx.nn as nn
10-
import mlx.optimizers as optim
118
import torch
12-
from mlx.utils import tree_map
139

1410
from chronos.model.model_chronos import ChronosForCausalLM
1511
from chronos.model.checkpoint import (
@@ -18,11 +14,26 @@
1814
load_state_dict_controlled,
1915
save_state_dict_with_config,
2016
)
21-
from chronos.mlx.model import ChronosMLXModel
22-
from chronos.mlx.moe import ChronosMLXMOE
23-
from chronos.mlx.training.io import mlx_state_to_torch
2417
from chronos.trainer.optim_utils import get_lr
2518

19+
try:
20+
import mlx.core as mx
21+
import mlx.nn as nn
22+
import mlx.optimizers as optim
23+
from mlx.utils import tree_map
24+
25+
from chronos.mlx.model import ChronosMLXModel
26+
from chronos.mlx.moe import ChronosMLXMOE
27+
from chronos.mlx.training.io import mlx_state_to_torch
28+
except ModuleNotFoundError:
29+
mx = None
30+
nn = None
31+
optim = None
32+
tree_map = None
33+
ChronosMLXModel = None
34+
ChronosMLXMOE = None
35+
mlx_state_to_torch = None
36+
2637

2738
def _ids_to_mx(x) -> mx.array:
2839
if isinstance(x, mx.array):
@@ -494,7 +505,7 @@ def _planned_total_steps(data_iter, epochs: int, max_steps) -> int:
494505
def _normalize_mlx_dtype_name(dtype_name: str | None) -> str:
495506
value = (dtype_name or "auto").strip().lower()
496507
if value in {"auto", ""}:
497-
return "bfloat16" if hasattr(mx, "bfloat16") else "float32"
508+
return "bfloat16" if mx is not None and hasattr(mx, "bfloat16") else "float32"
498509
if value in {"fp16", "float16", "half"}:
499510
return "float16"
500511
if value in {"bf16", "bfloat16"}:

chronos/trainer/device_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,7 @@ def dataloader_kwargs(
284284
.lower()
285285
in {"1", "true", "yes", "on"}
286286
)
287-
force_single_process = (
288-
sys.platform == "darwin"
289-
and metal_backend
290-
and not allow_metal_workers
291-
)
287+
force_single_process = metal_backend and not allow_metal_workers
292288
if num_workers in (None, "", "auto"):
293289
workers = 0 if force_single_process else max(1, min(4, _physical_cores() // 4))
294290
if device_type == "xpu":

tests/test_smoke.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -642,12 +642,16 @@ def test_configure_cpu_threads_overrides_single_thread_env(monkeypatch):
642642

643643

644644
def test_configure_cpu_threads_ignores_stale_chronos_env_by_default(monkeypatch):
645-
import psutil
646645
from chronos.trainer.device_utils import configure_cpu_threads, cpu_thread_snapshot
647646

648647
monkeypatch.setenv("CHRONOS_CPU_THREADS", "1")
649648
monkeypatch.setenv("OMP_NUM_THREADS", "1")
650-
physical = int(psutil.cpu_count(logical=False) or os.cpu_count() or 1)
649+
try:
650+
import psutil
651+
652+
physical = int(psutil.cpu_count(logical=False) or os.cpu_count() or 1)
653+
except Exception:
654+
physical = int(os.cpu_count() or 1)
651655
threads = configure_cpu_threads("auto", budget_percent=100)
652656
snap = cpu_thread_snapshot()
653657
assert threads == physical
@@ -2012,6 +2016,8 @@ def test_inference_stats_helpers_are_structured():
20122016
"resident_hit_rate": 0.75,
20132017
"prediction_hit_rate": 0.5,
20142018
"on_demand_loads": 1,
2019+
"prefetch_queue_drops": 0,
2020+
"prefetch_wait_time_s": 0.012,
20152021
"async_cold_miss_prefetches": 2,
20162022
"sync_ssd_loads": 1,
20172023
"miss_policy": "on_demand",
@@ -2051,7 +2057,8 @@ def test_inference_stats_helpers_are_structured():
20512057
assert "Setup RSS delta" in md and "Prefill RSS delta" in md and "Decode RSS" in md
20522058
assert "1.200 GB" in md and "2.030 GB" in md
20532059
assert "Load budget" in md
2054-
assert "On-demand loads" in md and "Async misses" in md and "Predict hit" in md
2060+
assert "On-demand loads" in md and "Prefetch wait" in md and "Predict hit" in md
2061+
assert "Expert hits/misses" in md
20552062
assert "lazy_offload" in md and "full_dram" in md
20562063
assert set(df.columns) == {"metric", "mode", "x", "value", "normalized_value", "unit"}
20572064
assert set(df["mode"]) == {"lazy_offload", "full_dram"}
@@ -2074,6 +2081,9 @@ def test_inference_offload_budget_caps_at_125_percent():
20742081
assert budget["effective_vram_expert_budget"] == 6
20752082
assert budget["effective_ram_expert_budget"] == 6
20762083
assert budget["routing_top_k"] == 3
2084+
low = _bounded_offload_expert_budget(cfg, 0.10)
2085+
assert low["effective_expert_budget"] == 1
2086+
assert low["effective_ram_expert_budget"] == 1
20772087
assert budget["num_moe_layers"] == 8
20782088

20792089
cfg2 = ChronosConfig(num_experts=64, num_experts_per_tok=4, num_hidden_layers=8)
@@ -2090,6 +2100,7 @@ def test_inference_ram_load_ratio_accepts_custom_values():
20902100
RAM_LOAD_RATIO_CHOICES,
20912101
RAM_LOAD_SWEEP_RATIOS,
20922102
_bounded_offload_expert_budget,
2103+
_clone_model_cfg,
20932104
_normalize_ram_load_ratio,
20942105
)
20952106

@@ -2101,15 +2112,27 @@ def test_inference_ram_load_ratio_accepts_custom_values():
21012112
custom = _bounded_offload_expert_budget(cfg, "0.33")
21022113
assert custom["requested_ram_load_ratio"] == 0.33
21032114
assert custom["effective_expert_budget"] == 11
2104-
assert custom["effective_ram_expert_budget"] == 32
2115+
assert custom["effective_ram_expert_budget"] == 11
21052116

21062117
custom_high = _bounded_offload_expert_budget(cfg, "1.10")
21072118
assert custom_high["requested_ram_load_ratio"] == 1.1
21082119
assert custom_high["effective_expert_budget"] == 36
2109-
assert custom_high["effective_ram_expert_budget"] == 40
2120+
assert custom_high["effective_ram_expert_budget"] == 36
21102121

21112122
assert _normalize_ram_load_ratio("not-a-number") == 1.0
21122123

2124+
small = ChronosConfig(
2125+
num_experts=6,
2126+
num_experts_per_tok=3,
2127+
num_hidden_layers=8,
2128+
recommended_resident_experts=2,
2129+
)
2130+
budgets = [
2131+
_bounded_offload_expert_budget(_clone_model_cfg(small), ratio)["effective_expert_budget"]
2132+
for ratio in RAM_LOAD_SWEEP_RATIOS
2133+
]
2134+
assert budgets == [1, 2, 2, 2, 3, 5, 5, 6, 6, 6, 6]
2135+
21132136

21142137
def test_generate_api_returns_plain_json_with_chart_records():
21152138
from ui.tabs import inference_tab as mod

0 commit comments

Comments
 (0)