From dfe92c89f44330c1b3606d064f7126d0dcbdbf39 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Sat, 7 Mar 2026 20:19:23 -0600 Subject: [PATCH 1/2] Slice epsilons with dead channels to match gains during agc. --- .../preprocessing/highpass_spatial_filter.py | 4 +++- .../tests/test_highpass_spatial_filter.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index c542725dd7..dfa049bc52 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -308,7 +308,9 @@ def agc(traces, window, epsilons): dead_channels = np.sum(gain, axis=0) == 0 - traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilons, gain[:, ~dead_channels]) + traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum( + epsilons[~dead_channels], gain[:, ~dead_channels] + ) return traces, gain diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 4aa014bbeb..2ad399b8db 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -103,6 +103,25 @@ def test_highpass_spatial_filter_synthetic_data(num_channels, ntr_pad, ntr_tap, assert raw_traces.shape == si_filtered.shape +def test_highpass_spatial_filter_with_dead_channels(): + """Regression test: AGC must handle dead (all-zero) channels without broadcast error. + + PR #4286 changed epsilon from a scalar to a per-channel array, but the agc() + function indexed gain with ~dead_channels without applying the same mask to + epsilons, causing a broadcast error when any channels had zero signal. + """ + num_channels = 32 + rec = generate_recording(num_channels=num_channels, durations=[0.5]) + # Materialize traces and zero out 3 channels to make them "dead" + traces = rec.get_traces().copy() + traces[:, [0, 15, 31]] = 0.0 + rec_with_dead = rec.save(format="memory") + rec_with_dead._recording_segments[0]._traces = traces + filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2) + result = filtered.get_traces() + assert result.shape == traces.shape + + @pytest.mark.parametrize("dtype", [np.int16, np.float32, np.float64]) def test_dtype_stability(dtype): """ From 7c4bb68afb49a1984f0108206db9d69f995f8313 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 9 Mar 2026 15:57:20 +0000 Subject: [PATCH 2/2] fix: use NumpyRecording instead if save to memory --- .../preprocessing/highpass_spatial_filter.py | 2 +- .../preprocessing/tests/test_highpass_spatial_filter.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index dfa049bc52..f64e553980 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -131,7 +131,7 @@ def __init__( rms_values = recording.get_property("noise_level_rms_raw") else: random_slice_kwargs = {} if random_slice_kwargs is None else random_slice_kwargs - rms_values = get_noise_levels(recording, method="rms", return_scaled=False, **random_slice_kwargs) + rms_values = get_noise_levels(recording, method="rms", return_in_uV=False, **random_slice_kwargs) # Pre-compute spatial filtering parameters butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn) diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 2ad399b8db..bfa4d3d9ae 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -6,7 +6,7 @@ import spikeinterface.core as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se -from spikeinterface.core import generate_recording +from spikeinterface.core import generate_recording, NumpyRecording import importlib.util ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -115,8 +115,10 @@ def test_highpass_spatial_filter_with_dead_channels(): # Materialize traces and zero out 3 channels to make them "dead" traces = rec.get_traces().copy() traces[:, [0, 15, 31]] = 0.0 - rec_with_dead = rec.save(format="memory") - rec_with_dead._recording_segments[0]._traces = traces + rec_with_dead = NumpyRecording( + traces_list=[traces], sampling_frequency=rec.sampling_frequency, channel_ids=rec.channel_ids + ) + rec_with_dead.set_probe(rec.get_probe(), in_place=True) filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2) result = filtered.get_traces() assert result.shape == traces.shape