1- from signal import signal
2-
3- from toolz import first
4- from torch .onnx .symbolic_opset11 import chunk
5-
61from spikeinterface import BaseRecording
72import numpy as np
83from spikeinterface .sortingcomponents .motion .motion_utils import make_2d_motion_histogram
@@ -62,9 +57,7 @@ def get_activity_histogram(
6257 peak_locations ,
6358 weight_with_amplitude = False ,
6459 direction = "y" ,
65- bin_s = (
66- bin_s if bin_s is not None else recording .get_duration (segment_index = 0 )
67- ), # TODO: doube cehck is this already scaling?
60+ bin_s = (bin_s if bin_s is not None else recording .get_duration (segment_index = 0 )),
6861 bin_um = None ,
6962 hist_margin_um = None ,
7063 spatial_bin_edges = spatial_bin_edges ,
@@ -95,15 +88,20 @@ def get_bin_centers(bin_edges):
9588
9689def estimate_chunk_size (scaled_activity_histogram ):
9790 """
98- Get an estimate of chunk size such that
99- the 80th percentile of the firing rate will be
100- estimated within 10% 90% of the time,
91+ Estimate a chunk size based on the firing rate. Intuitively, we
92+ want longer chunk size to better estimate low firing rates. The
93+ estimation computes a summary of the firing rates for the session
94+ by taking the value 25% of the max of the activity histogram.
95+
96+ Then, the chunk size that will accurately estimate this firing rate
97+ within 90% accuracy, 90% of the time based on assumption of Poisson
98+ firing (based on CLT) is computed.
10199
102- I think a better way is to take the peaks above half width and find the min.
103- Or just to take the 50th percentile...? NO. Because all peaks might be similar heights
100+ Parameters
101+ ----------
104102
105- corrected based on assumption
106- of Poisson firing (based on CLT) .
103+ scaled_activity_histogram: np.ndarray
104+ The activity histogram scaled to firing rate in Hz .
107105
108106 TODO
109107 ----
@@ -162,7 +160,7 @@ def get_chunked_hist_supremum(chunked_session_histograms):
162160
163161 min_hist = np .min (chunked_session_histograms , axis = 0 )
164162
165- scaled_range = (max_hist - min_hist ) / max_hist # TODO: no idea if this is a good idea or not
163+ scaled_range = (max_hist - min_hist ) / ( max_hist + 1e-12 )
166164
167165 return max_hist , scaled_range
168166
@@ -201,28 +199,31 @@ def get_chunked_hist_eigenvector(chunked_session_histograms):
201199 """
202200 TODO: a little messy with the 2D stuff. Will probably deprecate anyway.
203201 """
204- if chunked_session_histograms .shape [0 ] == 1 : # TODO: handle elsewhere
202+ if chunked_session_histograms .shape [0 ] == 1 :
205203 return chunked_session_histograms .squeeze (), None
206204
207205 is_2d = chunked_session_histograms .ndim == 3
208206
209207 if is_2d :
210- num_hist , num_spat_bin , num_amp_bin = chunked_histograms .shape
208+ num_hist , num_spat_bin , num_amp_bin = chunked_session_histograms .shape
211209 chunked_session_histograms = np .reshape (chunked_session_histograms , (num_hist , num_spat_bin * num_amp_bin ))
212210
213211 A = chunked_session_histograms
214212 S = (1 / A .shape [0 ]) * A .T @ A
215213
216- U , S , Vh = np .linalg .svd (S ) # TODO: this is already symmetric PSD so use eig
214+ L , U = np .linalg .eigh (S )
217215
218- first_eigenvector = U [:, 0 ] * np .sqrt (S [ 0 ])
219- first_eigenvector = np .abs (first_eigenvector ) # sometimes the eigenvector can be negative
216+ first_eigenvector = U [:, - 1 ] * np .sqrt (L [ - 1 ])
217+ first_eigenvector = np .abs (first_eigenvector ) # sometimes the eigenvector is negative
220218
219+ # Project all vectors (histograms) onto the principal component,
220+ # then take the standard deviation in each dimension (over bins)
221221 v1 = first_eigenvector [:, np .newaxis ]
222- reconstruct = (A @ v1 ) @ v1 .T
223- v1_std = np .std (np .sqrt (reconstruct ), axis = 0 , ddof = 0 ) # TODO: double check sqrt works out
222+ projection_onto_v1 = (A @ v1 @ v1 .T ) / (v1 .T @ v1 )
224223
225- if is_2d :
224+ v1_std = np .std (projection_onto_v1 , axis = 0 )
225+
226+ if is_2d : # TODO: double check this
226227 first_eigenvector = np .reshape (first_eigenvector , (num_spat_bin , num_amp_bin ))
227228 v1_std = np .reshape (v1_std , (num_spat_bin , num_amp_bin ))
228229
@@ -423,7 +424,9 @@ def compute_histogram_crosscorrelation(
423424 windowed_histogram_i = session_histogram_list [i , :] * window
424425 windowed_histogram_j = session_histogram_list [j , :] * window
425426
426- xcorr = np .correlate (windowed_histogram_i , windowed_histogram_j , mode = "full" )
427+ xcorr = np .correlate (
428+ windowed_histogram_i , windowed_histogram_j , mode = "full"
429+ ) # TODO: add weight option.
427430
428431 if num_shifts_block :
429432 window_indices = np .arange (center_bin - num_shifts_block , center_bin + num_shifts_block )
@@ -435,6 +438,14 @@ def compute_histogram_crosscorrelation(
435438
436439 # Smooth the cross-correlations across the bins
437440 if smoothing_sigma_bin :
441+ breakpoint ()
442+ import matplotlib .pyplot as plt
443+
444+ plt .plot (xcorr_matrix [0 , :])
445+ X = gaussian_filter (xcorr_matrix , smoothing_sigma_bin , axes = 1 )
446+ plt .plot (X [0 , :])
447+ plt .show ()
448+
438449 xcorr_matrix = gaussian_filter (xcorr_matrix , smoothing_sigma_bin , axes = 1 )
439450
440451 # Smooth the cross-correlations across the windows
@@ -495,3 +506,79 @@ def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray:
495506 cut_padded_array = padded_hist [abs_shift :] if shift >= 0 else padded_hist [:- abs_shift ]
496507
497508 return cut_padded_array
509+
510+
511+ def akima_interpolate_nonrigid_shifts (
512+ non_rigid_shifts : np .ndarray ,
513+ non_rigid_window_centers : np .ndarray ,
514+ spatial_bin_centers : np .ndarray ,
515+ ):
516+ """
517+ Perform Akima spline interpolation on a set of non-rigid shifts.
518+ The non-rigid shifts are per segment of the probe, each segment
519+ containing a number of channels. Interpolating these non-rigid
520+ shifts to the spatial bin centers gives a more accurate shift
521+ per channel.
522+
523+ Parameters
524+ ----------
525+ non_rigid_shifts : np.ndarray
526+ non_rigid_window_centers : np.ndarray
527+ spatial_bin_centers : np.ndarray
528+
529+ Returns
530+ -------
531+ interp_nonrigid_shifts : np.ndarray
532+ An array (length num_spatial_bins) of shifts
533+ interpolated from the non-rigid shifts.
534+
535+ TODO
536+ ----
537+ requires scipy 14
538+ """
539+ from scipy .interpolate import Akima1DInterpolator
540+
541+ x = non_rigid_window_centers
542+ xs = spatial_bin_centers
543+
544+ num_sessions = non_rigid_shifts .shape [0 ]
545+ num_bins = spatial_bin_centers .shape [0 ]
546+
547+ interp_nonrigid_shifts = np .zeros ((num_sessions , num_bins ))
548+ for ses_idx in range (num_sessions ):
549+
550+ y = non_rigid_shifts [ses_idx ]
551+ y_new = Akima1DInterpolator (x , y , method = "akima" , extrapolate = True )(xs )
552+ interp_nonrigid_shifts [ses_idx , :] = y_new
553+
554+ return interp_nonrigid_shifts
555+
556+
def get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray):
    """
    Given a matrix of displacements between all sessions, find the
    shifts (one per session) to bring the sessions into alignment.

    Parameters
    ----------
    alignment_order : str
        "to_middle" or "to_session_X" where "X" is the 1-based
        number of the session to align to.
    session_offsets_matrix : np.ndarray
        The num_sessions x num_sessions (x num_bins) matrix
        of displacements between all sessions, generated by
        `_compute_session_alignment()`.

    Returns
    -------
    optimal_shift_indices : np.ndarray
        A 1 x num_sessions array of shifts to apply to
        each session in order to bring all sessions into
        alignment.

    Raises
    ------
    ValueError
        If `alignment_order` is not "to_middle" or "to_session_X".
    """
    if alignment_order == "to_middle":
        # Shifting each session by minus the mean of its displacements
        # to every other session centers all sessions on their average.
        optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0)
    elif alignment_order.startswith("to_session_"):
        # Align everything to one reference session (1-based in the name).
        ses_idx = int(alignment_order.split("_")[-1]) - 1
        optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :]
    else:
        raise ValueError(
            f"`alignment_order` must be 'to_middle' or 'to_session_X', got: {alignment_order}"
        )

    return optimal_shift_indices
0 commit comments