Skip to content

Commit 2c1ebca

Browse files
committed
Tidying up, begin fixing alignment alg.
1 parent 396434e commit 2c1ebca

2 files changed

Lines changed: 119 additions & 106 deletions

File tree

src/spikeinterface/preprocessing/inter_session_alignment/alignment_utils.py

Lines changed: 112 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,3 @@
1-
from signal import signal
2-
3-
from toolz import first
4-
from torch.onnx.symbolic_opset11 import chunk
5-
61
from spikeinterface import BaseRecording
72
import numpy as np
83
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram
@@ -62,9 +57,7 @@ def get_activity_histogram(
6257
peak_locations,
6358
weight_with_amplitude=False,
6459
direction="y",
65-
bin_s=(
66-
bin_s if bin_s is not None else recording.get_duration(segment_index=0)
67-
), # TODO: doube cehck is this already scaling?
60+
bin_s=(bin_s if bin_s is not None else recording.get_duration(segment_index=0)),
6861
bin_um=None,
6962
hist_margin_um=None,
7063
spatial_bin_edges=spatial_bin_edges,
@@ -95,15 +88,20 @@ def get_bin_centers(bin_edges):
9588

9689
def estimate_chunk_size(scaled_activity_histogram):
9790
"""
98-
Get an estimate of chunk size such that
99-
the 80th percentile of the firing rate will be
100-
estimated within 10% 90% of the time,
91+
Estimate a chunk size based on the firing rate. Intuitively, we
92+
want longer chunk size to better estimate low firing rates. The
93+
estimation computes a summary of the the firing rates for the session
94+
by taking the value 25% of the max of the activity histogram.
95+
96+
Then, the chunk size that will accurately estimate this firing rate
97+
within 90% accuracy, 90% of the time based on assumption of Poisson
98+
firing (based on CLT) is computed.
10199
102-
I think a better way is to take the peaks above half width and find the min.
103-
Or just to take the 50th percentile...? NO. Because all peaks might be similar heights
100+
Parameters
101+
----------
104102
105-
corrected based on assumption
106-
of Poisson firing (based on CLT).
103+
scaled_activity_histogram: np.ndarray
104+
The activity histogram scaled to firing rate in Hz.
107105
108106
TODO
109107
----
@@ -162,7 +160,7 @@ def get_chunked_hist_supremum(chunked_session_histograms):
162160

163161
min_hist = np.min(chunked_session_histograms, axis=0)
164162

165-
scaled_range = (max_hist - min_hist) / max_hist # TODO: no idea if this is a good idea or not
163+
scaled_range = (max_hist - min_hist) / (max_hist + 1e-12)
166164

167165
return max_hist, scaled_range
168166

@@ -201,28 +199,31 @@ def get_chunked_hist_eigenvector(chunked_session_histograms):
201199
"""
202200
TODO: a little messy with the 2D stuff. Will probably deprecate anyway.
203201
"""
204-
if chunked_session_histograms.shape[0] == 1: # TODO: handle elsewhere
202+
if chunked_session_histograms.shape[0] == 1:
205203
return chunked_session_histograms.squeeze(), None
206204

207205
is_2d = chunked_session_histograms.ndim == 3
208206

209207
if is_2d:
210-
num_hist, num_spat_bin, num_amp_bin = chunked_histograms.shape
208+
num_hist, num_spat_bin, num_amp_bin = chunked_session_histograms.shape
211209
chunked_session_histograms = np.reshape(chunked_session_histograms, (num_hist, num_spat_bin * num_amp_bin))
212210

213211
A = chunked_session_histograms
214212
S = (1 / A.shape[0]) * A.T @ A
215213

216-
U, S, Vh = np.linalg.svd(S) # TODO: this is already symmetric PSD so use eig
214+
L, U = np.linalg.eigh(S)
217215

218-
first_eigenvector = U[:, 0] * np.sqrt(S[0])
219-
first_eigenvector = np.abs(first_eigenvector) # sometimes the eigenvector can be negative
216+
first_eigenvector = U[:, -1] * np.sqrt(L[-1])
217+
first_eigenvector = np.abs(first_eigenvector) # sometimes the eigenvector is negative
220218

219+
# Project all vectors (histograms) onto the principal component,
220+
# then take the standard deviation in each dimension (over bins)
221221
v1 = first_eigenvector[:, np.newaxis]
222-
reconstruct = (A @ v1) @ v1.T
223-
v1_std = np.std(np.sqrt(reconstruct), axis=0, ddof=0) # TODO: double check sqrt works out
222+
projection_onto_v1 = (A @ v1 @ v1.T) / (v1.T @ v1)
224223

225-
if is_2d:
224+
v1_std = np.std(projection_onto_v1, axis=0)
225+
226+
if is_2d: # TODO: double check this
226227
first_eigenvector = np.reshape(first_eigenvector, (num_spat_bin, num_amp_bin))
227228
v1_std = np.reshape(v1_std, (num_spat_bin, num_amp_bin))
228229

@@ -423,7 +424,9 @@ def compute_histogram_crosscorrelation(
423424
windowed_histogram_i = session_histogram_list[i, :] * window
424425
windowed_histogram_j = session_histogram_list[j, :] * window
425426

426-
xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full")
427+
xcorr = np.correlate(
428+
windowed_histogram_i, windowed_histogram_j, mode="full"
429+
) # TODO: add weight option.
427430

428431
if num_shifts_block:
429432
window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block)
@@ -435,6 +438,14 @@ def compute_histogram_crosscorrelation(
435438

436439
# Smooth the cross-correlations across the bins
437440
if smoothing_sigma_bin:
441+
breakpoint()
442+
import matplotlib.pyplot as plt
443+
444+
plt.plot(xcorr_matrix[0, :])
445+
X = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)
446+
plt.plot(X[0, :])
447+
plt.show()
448+
438449
xcorr_matrix = gaussian_filter(xcorr_matrix, smoothing_sigma_bin, axes=1)
439450

440451
# Smooth the cross-correlations across the windows
@@ -495,3 +506,79 @@ def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray:
495506
cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]
496507

497508
return cut_padded_array
509+
510+
511+
def akima_interpolate_nonrigid_shifts(
512+
non_rigid_shifts: np.ndarray,
513+
non_rigid_window_centers: np.ndarray,
514+
spatial_bin_centers: np.ndarray,
515+
):
516+
"""
517+
Perform Akima spline interpolation on a set of non-rigid shifts.
518+
The non-rigid shifts are per segment of the probe, each segment
519+
containing a number of channels. Interpolating these non-rigid
520+
shifts to the spatial bin centers gives a more accurate shift
521+
per channel.
522+
523+
Parameters
524+
----------
525+
non_rigid_shifts : np.ndarray
526+
non_rigid_window_centers : np.ndarray
527+
spatial_bin_centers : np.ndarray
528+
529+
Returns
530+
-------
531+
interp_nonrigid_shifts : np.ndarray
532+
An array (length num_spatial_bins) of shifts
533+
interpolated from the non-rigid shifts.
534+
535+
TODO
536+
----
537+
requires scipy 14
538+
"""
539+
from scipy.interpolate import Akima1DInterpolator
540+
541+
x = non_rigid_window_centers
542+
xs = spatial_bin_centers
543+
544+
num_sessions = non_rigid_shifts.shape[0]
545+
num_bins = spatial_bin_centers.shape[0]
546+
547+
interp_nonrigid_shifts = np.zeros((num_sessions, num_bins))
548+
for ses_idx in range(num_sessions):
549+
550+
y = non_rigid_shifts[ses_idx]
551+
y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs)
552+
interp_nonrigid_shifts[ses_idx, :] = y_new
553+
554+
return interp_nonrigid_shifts
555+
556+
557+
def get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray):
558+
"""
559+
Given a matrix of displacements between all sessions, find the
560+
shifts (one per session) to bring the sessions into alignment.
561+
562+
Parameters
563+
----------
564+
alignment_order : "to_middle" or "to_session_X" where
565+
"N" is the number of the session to align to.
566+
session_offsets_matrix : np.ndarray
567+
The num_sessions x num_sessions symmetric matrix
568+
of displacements between all sessions, generated by
569+
`_compute_session_alignment()`.
570+
571+
Returns
572+
-------
573+
optimal_shift_indices : np.ndarray
574+
A 1 x num_sessions array of shifts to apply to
575+
each session in order to bring all sessions into
576+
alignment.
577+
"""
578+
if alignment_order == "to_middle":
579+
optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0)
580+
else:
581+
ses_idx = int(alignment_order.split("_")[-1]) - 1
582+
optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :]
583+
584+
return optimal_shift_indices

src/spikeinterface/preprocessing/inter_session_alignment/session_alignment.py

Lines changed: 7 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def align_sessions(
211211
interpolate_motion_kwargs = copy.deepcopy(interpolate_motion_kwargs)
212212

213213
# Ensure list lengths match and all channel locations are the same across recordings.
214-
_check_align_sesssions_inputs(
214+
_check_align_sessions_inputs(
215215
recordings_list, peaks_list, peak_locations_list, alignment_order, estimate_histogram_kwargs
216216
)
217217

@@ -894,11 +894,11 @@ def _compute_session_alignment(
894894
nonrigid_session_offsets_matrix = alignment_utils.compute_histogram_crosscorrelation(
895895
shifted_histograms, non_rigid_windows, **compute_alignment_kwargs
896896
)
897-
non_rigid_shifts = _get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
897+
non_rigid_shifts = alignment_utils.get_shifts_from_session_matrix(alignment_order, nonrigid_session_offsets_matrix)
898898

899899
# Akima interpolate the nonrigid bins if required.
900900
if akima_interp_nonrigid:
901-
interp_nonrigid_shifts = _akima_interpolate_nonrigid_shifts(
901+
interp_nonrigid_shifts = alignment_utils.akima_interpolate_nonrigid_shifts(
902902
non_rigid_shifts, non_rigid_window_centers, spatial_bin_centers
903903
)
904904
shifts = rigid_shifts + interp_nonrigid_shifts
@@ -944,83 +944,9 @@ def _estimate_rigid_alignment(
944944
rigid_window,
945945
**compute_alignment_kwargs,
946946
)
947-
optimal_shift_indices = _get_shifts_from_session_matrix(alignment_order, rigid_session_offsets_matrix)
948-
949-
return optimal_shift_indices
950-
951-
952-
def _akima_interpolate_nonrigid_shifts(
953-
non_rigid_shifts: np.ndarray,
954-
non_rigid_window_centers: np.ndarray,
955-
spatial_bin_centers: np.ndarray,
956-
):
957-
"""
958-
Perform Akima spline interpolation on a set of non-rigid shifts.
959-
The non-rigid shifts are per segment of the probe, each segment
960-
containing a number of channels. Interpolating these non-rigid
961-
shifts to the spatial bin centers gives a more accurate shift
962-
per channel.
963-
964-
Parameters
965-
----------
966-
non_rigid_shifts : np.ndarray
967-
non_rigid_window_centers : np.ndarray
968-
spatial_bin_centers : np.ndarray
969-
970-
Returns
971-
-------
972-
interp_nonrigid_shifts : np.ndarray
973-
An array (length num_spatial_bins) of shifts
974-
interpolated from the non-rigid shifts.
975-
976-
TODO
977-
----
978-
requires scipy 14
979-
"""
980-
from scipy.interpolate import Akima1DInterpolator
981-
982-
x = non_rigid_window_centers
983-
xs = spatial_bin_centers
984-
985-
num_sessions = non_rigid_shifts.shape[0]
986-
num_bins = spatial_bin_centers.shape[0]
987-
988-
interp_nonrigid_shifts = np.zeros((num_sessions, num_bins))
989-
for ses_idx in range(num_sessions):
990-
991-
y = non_rigid_shifts[ses_idx]
992-
y_new = Akima1DInterpolator(x, y, method="akima", extrapolate=True)(xs)
993-
interp_nonrigid_shifts[ses_idx, :] = y_new
994-
995-
return interp_nonrigid_shifts
996-
997-
998-
def _get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix: np.ndarray):
999-
"""
1000-
Given a matrix of displacements between all sessions, find the
1001-
shifts (one per session) to bring the sessions into alignment.
1002-
1003-
Parameters
1004-
----------
1005-
alignment_order : "to_middle" or "to_session_X" where
1006-
"N" is the number of the session to align to.
1007-
session_offsets_matrix : np.ndarray
1008-
The num_sessions x num_sessions symmetric matrix
1009-
of displacements between all sessions, generated by
1010-
`_compute_session_alignment()`.
1011-
1012-
Returns
1013-
-------
1014-
optimal_shift_indices : np.ndarray
1015-
A 1 x num_sessions array of shifts to apply to
1016-
each session in order to bring all sessions into
1017-
alignment.
1018-
"""
1019-
if alignment_order == "to_middle":
1020-
optimal_shift_indices = -np.mean(session_offsets_matrix, axis=0)
1021-
else:
1022-
ses_idx = int(alignment_order.split("_")[-1]) - 1
1023-
optimal_shift_indices = -session_offsets_matrix[ses_idx, :, :]
947+
optimal_shift_indices = alignment_utils.get_shifts_from_session_matrix(
948+
alignment_order, rigid_session_offsets_matrix
949+
)
1024950

1025951
return optimal_shift_indices
1026952

@@ -1030,7 +956,7 @@ def _get_shifts_from_session_matrix(alignment_order: str, session_offsets_matrix
1030956
# -----------------------------------------------------------------------------
1031957

1032958

1033-
def _check_align_sesssions_inputs(
959+
def _check_align_sessions_inputs(
1034960
recordings_list: list[BaseRecording],
1035961
peaks_list: list[np.ndarray],
1036962
peak_locations_list: list[np.ndarray],

0 commit comments

Comments
 (0)