@@ -50,8 +50,14 @@ def __init__(self, model, config, ssd_dir: str = "./expert_cache_mlx"):
5050 self ._runtime_stats = {
5151 "resident_hits" : 0 ,
5252 "resident_misses" : 0 ,
53+ "resident_vram_hits" : 0 ,
54+ "resident_ram_hits" : 0 ,
55+ "selection_hits" : 0 ,
56+ "selection_misses" : 0 ,
5357 "prediction_hits" : 0 ,
5458 "prediction_total" : 0 ,
59+ "prefetch_queue_drops" : 0 ,
60+ "prefetch_wait_time_s" : 0.0 ,
5561 "sync_ssd_loads" : 0 ,
5662 "on_demand_load_time_s" : 0.0 ,
5763 }
@@ -75,6 +81,11 @@ def loader(eid: int) -> bool:
7581 eid in self .store ._warm
7682 and self .store ._layer_states_complete (self .store ._warm .get (eid ))
7783 )
84+ with self ._stats_lock :
85+ if was_hot or was_warm :
86+ self ._runtime_stats ["selection_hits" ] += 1
87+ else :
88+ self ._runtime_stats ["selection_misses" ] += 1
7889 t0 = time .monotonic ()
7990 if not was_hot and not was_warm :
8091 with self ._stats_lock :
@@ -85,18 +96,29 @@ def loader(eid: int) -> bool:
8596 self ._runtime_stats ["on_demand_load_time_s" ] += elapsed
8697 if ok :
8798 self ._runtime_stats ["resident_hits" if (was_hot or was_warm ) else "resident_misses" ] += 1
99+ if was_hot :
100+ self ._runtime_stats ["resident_vram_hits" ] += 1
101+ elif was_warm :
102+ self ._runtime_stats ["resident_ram_hits" ] += 1
88103 else :
89104 self ._runtime_stats ["resident_misses" ] += 1
90105 self ._runtime_stats ["prediction_total" ] += 1
91- if eid in self ._last_predicted :
106+ if ( was_hot or was_warm ) and eid in self ._last_predicted :
92107 self ._runtime_stats ["prediction_hits" ] += 1
93108 return ok
94109
95110 def touch (eid : int ) -> None :
111+ eid = int (eid )
96112 with self .store ._lock :
97113 if eid in self .store ._hot_lru :
98- if self .store ._expert_live_all_layers (int ( eid ) ):
114+ if self .store ._expert_live_all_layers (eid ):
99115 self .store ._hot_lru .move_to_end (eid )
116+ with self ._stats_lock :
117+ self ._runtime_stats ["selection_hits" ] += 1
118+ self ._runtime_stats ["resident_vram_hits" ] += 1
119+ self ._runtime_stats ["prediction_total" ] += 1
120+ if eid in self ._last_predicted :
121+ self ._runtime_stats ["prediction_hits" ] += 1
100122 else :
101123 self .store ._hot_lru .pop (eid , None )
102124
@@ -119,10 +141,44 @@ def _prefetch_loop(self):
119141 self ._prefetch_q .task_done ()
120142
121143 def _schedule_prefetch (self , expert_ids : List [int ]):
144+ if not expert_ids :
145+ return
122146 try :
123147 self ._prefetch_q .put_nowait (expert_ids )
124148 except queue .Full :
125- pass
149+ with self ._stats_lock :
150+ self ._runtime_stats ["prefetch_queue_drops" ] += 1
151+
152+ def _prefetch_and_promote_window (self , expert_ids : List [int ], timeout_s : float = 0.012 ) -> None :
153+ if not expert_ids or self .store .storage_format == "full_dram" :
154+ return
155+ pending = []
156+ for eid in dict .fromkeys (int (eid ) for eid in expert_ids ):
157+ with self .store ._lock :
158+ if eid in self .store ._hot_lru :
159+ continue
160+ if eid in self .store ._warm and self .store ._layer_states_complete (self .store ._warm .get (eid )):
161+ pending .append (eid )
162+ continue
163+ pending .append (eid )
164+ if not pending :
165+ return
166+ self ._schedule_prefetch (pending )
167+ deadline = time .monotonic () + max (0.0 , float (timeout_s ))
168+ while time .monotonic () < deadline :
169+ self ._promote_ready (pending )
170+ with self .store ._lock :
171+ ready = all (
172+ eid in self .store ._hot_lru
173+ or (eid in self .store ._warm and self .store ._layer_states_complete (self .store ._warm .get (eid )))
174+ for eid in pending
175+ )
176+ if ready :
177+ break
178+ time .sleep (0.001 )
179+ self ._promote_ready (pending )
180+ with self ._stats_lock :
181+ self ._runtime_stats ["prefetch_wait_time_s" ] += max (0.0 , time .monotonic () - (deadline - max (0.0 , float (timeout_s ))))
126182
127183 def stop (self ):
128184 self ._stop .set ()
@@ -157,8 +213,14 @@ def generate(
157213 self ._runtime_stats = {
158214 "resident_hits" : 0 ,
159215 "resident_misses" : 0 ,
216+ "resident_vram_hits" : 0 ,
217+ "resident_ram_hits" : 0 ,
218+ "selection_hits" : 0 ,
219+ "selection_misses" : 0 ,
160220 "prediction_hits" : 0 ,
161221 "prediction_total" : 0 ,
222+ "prefetch_queue_drops" : 0 ,
223+ "prefetch_wait_time_s" : 0.0 ,
162224 "sync_ssd_loads" : 0 ,
163225 "on_demand_load_time_s" : 0.0 ,
164226 }
@@ -186,6 +248,10 @@ def generate(
186248 next_token = self ._sample (logits [:, - 1 , :], temperature , top_p )
187249 activated_ids : List [int ] = []
188250 tokens = 1
251+ if scheduler is None and lookahead_probs is not None :
252+ future_ids = self ._predict_future_experts (lookahead_probs )
253+ self ._last_predicted = set (int (eid ) for eid in future_ids )
254+ self ._prefetch_and_promote_window (future_ids , timeout_s = 0.025 )
189255 yield int (next_token .item ())
190256
191257 for _ in range (max_new_tokens - 1 ):
@@ -196,8 +262,7 @@ def generate(
196262 if lookahead_probs is not None :
197263 future_ids = self ._predict_future_experts (lookahead_probs )
198264 self ._last_predicted = set (int (eid ) for eid in future_ids )
199- self ._schedule_prefetch (future_ids )
200- self ._promote_ready (future_ids )
265+ self ._prefetch_and_promote_window (future_ids )
201266 avail_masks = self ._build_avail_masks (next_token )
202267
203268 token_in = next_token .reshape (1 , 1 )
@@ -272,16 +337,22 @@ def _promote_ready(self, expert_ids: List[int]) -> None:
272337 def _runtime_stat_fields (self ) -> dict :
273338 with self ._stats_lock :
274339 stats = dict (self ._runtime_stats )
275- total = int (stats .get ("resident_hits " , 0 )) + int (stats .get ("resident_misses " , 0 ))
340+ total = int (stats .get ("selection_hits " , 0 )) + int (stats .get ("selection_misses " , 0 ))
276341 pred_total = int (stats .get ("prediction_total" , 0 ))
277342 return {
278- "resident_hit_rate" : round (float (stats .get ("resident_hits" , 0 )) / max (total , 1 ), 4 ),
279- "cache_hit_rate" : round (float (stats .get ("resident_hits" , 0 )) / max (total , 1 ), 4 ),
280- "cache_hits" : int (stats .get ("resident_hits" , 0 )),
281- "cache_misses" : int (stats .get ("resident_misses" , 0 )),
343+ "resident_hit_rate" : round (float (stats .get ("selection_hits" , 0 )) / max (total , 1 ), 4 ),
344+ "cache_hit_rate" : round (float (stats .get ("selection_hits" , 0 )) / max (total , 1 ), 4 ),
345+ "cache_hits" : int (stats .get ("selection_hits" , 0 )),
346+ "cache_misses" : int (stats .get ("selection_misses" , 0 )),
347+ "expert_selection_hits" : int (stats .get ("selection_hits" , 0 )),
348+ "expert_selection_misses" : int (stats .get ("selection_misses" , 0 )),
282349 "prediction_hit_rate" : round (float (stats .get ("prediction_hits" , 0 )) / max (pred_total , 1 ), 4 ),
283350 "prediction_hits" : int (stats .get ("prediction_hits" , 0 )),
284351 "prediction_total" : pred_total ,
352+ "resident_vram_hits" : int (stats .get ("resident_vram_hits" , 0 )),
353+ "resident_ram_hits" : int (stats .get ("resident_ram_hits" , 0 )),
354+ "prefetch_queue_drops" : int (stats .get ("prefetch_queue_drops" , 0 )),
355+ "prefetch_wait_time_s" : round (float (stats .get ("prefetch_wait_time_s" , 0.0 )), 4 ),
285356 "sync_ssd_loads" : int (stats .get ("sync_ssd_loads" , 0 )),
286357 "on_demand_loads" : int (stats .get ("resident_misses" , 0 )),
287358 "on_demand_load_time_s" : round (float (stats .get ("on_demand_load_time_s" , 0.0 )), 4 ),
@@ -309,16 +380,18 @@ def _sample(logits: mx.array, temperature: float, top_p: float) -> mx.array:
309380 @staticmethod
310381 def _memory_snapshot () -> dict :
311382 try :
312- from chronos .backend .mac_diagnostics import mlx_memory_snapshot
383+ from chronos .backend .mac_diagnostics import mlx_memory_snapshot , rss_snapshot
313384
314- return mlx_memory_snapshot ()
385+ out = dict (mlx_memory_snapshot ())
386+ out .update (rss_snapshot ())
387+ return out
315388 except Exception :
316389 return {}
317390
318391 @staticmethod
319392 def _memory_fields (snapshot : dict , suffix : str ) -> dict :
320393 out = {}
321- for key in ("mlx_active_gb" , "mlx_cache_gb" , "mlx_peak_gb" ):
394+ for key in ("rss_gb" , " mlx_active_gb" , "mlx_cache_gb" , "mlx_peak_gb" ):
322395 if key in snapshot :
323396 prefix = key [:- 3 ] if key .endswith ("_gb" ) else key
324397 out [f"{ prefix } _{ suffix } _gb" ] = snapshot [key ]
0 commit comments