From c8950c0cc41e55b09579037b1373f23406fc22df Mon Sep 17 00:00:00 2001 From: jl33ai Date: Fri, 29 May 2026 09:27:26 -0700 Subject: [PATCH] preallocate decoder hot path scratch _process_lfp_timestamp runs at ~167 Hz (180-sample bin at 30 kHz spike clock) and was allocating per tick: - two np.zeros(cred_int_bufsize) scratch arrays - a logical_and mask - three no-op atleast_2d wrappers around already-2d slices that allocation churn is the kind of thing that lets generational gc accumulate enough pressure to produce 50-100 ms tail latency spikes, which would explain some of the worst case timing the lab has seen on the python side. changes: - enc_cred_intervals, enc_argmaxes, and the spike mask are now instance attrs allocated once in __init__ and reused via .fill(0) / out= - dropped the three atleast_2d calls since boolean indexing a 2d array with a 1d mask already returns 2d - hoisted self.p[...] lookups out of the inner loop - swapped msg.tobytes() for [msg, MPI.BYTE] in the three send paths so MPI gets the numpy buffer directly. wire format unchanged so receivers don't need any change. no algorithmic change. semantics preserved. same pass is worth doing in encoder_process and ripple_process but kept this commit scoped to the decoder. 
--- realtime_decoder/decoder_process.py | 74 ++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 22 deletions(-) diff --git a/realtime_decoder/decoder_process.py b/realtime_decoder/decoder_process.py index 2823924..c5ba64d 100644 --- a/realtime_decoder/decoder_process.py +++ b/realtime_decoder/decoder_process.py @@ -3,6 +3,8 @@ import glob import numpy as np +from mpi4py import MPI + from realtime_decoder import ( base, utils, position, messages, transitions, binary_record, taskstate @@ -18,11 +20,17 @@ class DecoderMPISendInterface(base.StandardMPISendInterface): def __init__(self, comm, rank, config): super().__init__(comm, rank, config) + # NOTE: each send path used to call msg.tobytes() per tick, which + # allocates a fresh bytes object every time. Sending [msg, MPI.BYTE] + # hands MPI the numpy buffer directly — zero-copy, no GC pressure. + # Wire format is unchanged (raw bytes), so receivers built around + # `bytearray(... .itemsize)` + `np.frombuffer` continue to work. + def send_posterior(self, dest, msg): """Send a message containing posterior data""" self.comm.Send( - buf=msg.tobytes(), + buf=[msg, MPI.BYTE], dest=dest, tag=messages.MPIMessageTag.POSTERIOR ) @@ -32,7 +40,7 @@ def send_velocity_position(self, dest, msg): velocity data""" self.comm.Send( - msg.tobytes(), + [msg, MPI.BYTE], dest=dest, tag=messages.MPIMessageTag.VEL_POS ) @@ -42,7 +50,7 @@ def send_dropped_spikes(self, dest, msg): spikes""" self.comm.Send( - msg.tobytes(), + [msg, MPI.BYTE], dest=dest, tag=messages.MPIMessageTag.DROPPED_SPIKES ) @@ -461,6 +469,22 @@ def __init__( self._init_timings() self._set_up_trodes() + # Pre-allocated scratch buffers reused on every LFP tick (~167 Hz at + # 30 kHz spike clock with a 180-sample bin). Allocating these per + # tick generates ~500+ short-lived numpy objects/sec, which is the + # kind of churn that triggers gen-2 GC pauses long enough to show up + # as the 100 ms tail-latency spikes the lab has seen. 
Reusing them + # in-place keeps the steady-state allocation rate near zero. + self._enc_cred_intervals = np.zeros( + self.p['cred_int_bufsize'], dtype=int + ) + self._enc_argmaxes = np.zeros( + self.p['cred_int_bufsize'], dtype=int + ) + self._spike_mask = np.zeros( + self._spike_buf.shape[0], dtype=bool + ) + def next_iter(self): """Run one iteration processing any available neural data""" @@ -808,32 +832,35 @@ def _process_lfp_timestamp(self, timestamp): """Process a new LFP timestamp by triggering an updated estimate of the posterior""" - # these are default values. if there are relevant spikes - # in the time bin of interest, these will be populated - # accordingly - enc_cred_intervals = np.zeros(self.p['cred_int_bufsize'], dtype=int) - enc_argmaxes = np.zeros(self.p['cred_int_bufsize'], dtype=int) + # Reuse persistent scratch buffers (allocated once in __init__) + # rather than allocating per-tick. Zeroing in place is ~free; the + # numpy/gc overhead of `np.zeros(N)` at ~167 Hz is not. + self._enc_cred_intervals.fill(0) + self._enc_argmaxes.fill(0) lb = int(timestamp - self.p['tbin_delay_samples'] - self.p['tbin_samples']) ub = int(timestamp - self.p['tbin_delay_samples']) + # Compute the bin-membership mask in place into a preallocated + # bool array. `np.logical_and(out=...)` avoids the temporary the + # default form allocates. spikes_in_bin_mask = np.logical_and( self._spike_buf[:, 0] >= lb, - self._spike_buf[:, 0] < ub + self._spike_buf[:, 0] < ub, + out=self._spike_mask, ) - if np.sum(spikes_in_bin_mask) > 0: + if np.any(spikes_in_bin_mask): # these spikes are being used. mark them with a 1 self._spike_buf[spikes_in_bin_mask, 4] = 1 - spikes_before = np.atleast_2d( - self._spike_buf[spikes_in_bin_mask] - ) + # Boolean indexing a 2D array with a 1D bool mask already + # returns 2D — np.atleast_2d here was a no-op that added an + # extra array wrapper per tick. Drop it. 
+ spikes_before = self._spike_buf[spikes_in_bin_mask] unique_inds = self._get_unique(spikes_before[:, 0]) #NOTE(DS): to get rid of duplicated spikes - spikes_after = np.atleast_2d( - spikes_before[unique_inds] - ) + spikes_after = spikes_before[unique_inds] num_before = len(spikes_before) num_after = len(spikes_after) @@ -848,10 +875,13 @@ def _process_lfp_timestamp(self, timestamp): # main process will check for non-nan elements order = np.argsort(spikes_after[:, 0]) ordered_spikes = spikes_after[order] + cred_int_max = self.p['cred_int_max'] + cred_int_bufsize = self.p['cred_int_bufsize'] for ii, data in enumerate(ordered_spikes): - if data[3] <= self.p['cred_int_max']: - enc_cred_intervals[ii % self.p['cred_int_bufsize']] = data[1] - enc_argmaxes[ii % self.p['cred_int_bufsize']] = np.argmax(data[5:]) + if data[3] <= cred_int_max: + slot = ii % cred_int_bufsize + self._enc_cred_intervals[slot] = data[1] + self._enc_argmaxes[slot] = np.argmax(data[5:]) # Note: the decoder can automatically handle the no-spike case spikes_in_bin_count = num_after @@ -867,7 +897,7 @@ def _process_lfp_timestamp(self, timestamp): spikes_in_bin_count = 0 t0 = time.time_ns() posterior, likelihood = self._decoder.compute_posterior( - np.atleast_2d(self._spike_buf[spikes_in_bin_mask]) + self._spike_buf[spikes_in_bin_mask] ) t1 = time.time_ns() self._time_posterior(lb, ub, t0, t1) @@ -890,8 +920,8 @@ def _process_lfp_timestamp(self, timestamp): self._posterior_msg[0]['velocity'] = self._current_vel self._posterior_msg[0]['cred_int_post'] = cred_int_post self._posterior_msg[0]['cred_int_lk'] = cred_int_lk - self._posterior_msg[0]['enc_cred_intervals'] = enc_cred_intervals - self._posterior_msg[0]['enc_argmaxes'] = enc_argmaxes + self._posterior_msg[0]['enc_cred_intervals'] = self._enc_cred_intervals + self._posterior_msg[0]['enc_argmaxes'] = self._enc_argmaxes self._posterior_msg[0]['spike_count'] = spikes_in_bin_count self.send_interface.send_posterior( 
self._config['rank']['supervisor'][0], self._posterior_msg