diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index c542725dd7..f64e553980 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -131,7 +131,7 @@ def __init__( rms_values = recording.get_property("noise_level_rms_raw") else: random_slice_kwargs = {} if random_slice_kwargs is None else random_slice_kwargs - rms_values = get_noise_levels(recording, method="rms", return_scaled=False, **random_slice_kwargs) + rms_values = get_noise_levels(recording, method="rms", return_in_uV=False, **random_slice_kwargs) # Pre-compute spatial filtering parameters butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn) @@ -308,7 +308,9 @@ def agc(traces, window, epsilons): dead_channels = np.sum(gain, axis=0) == 0 - traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilons, gain[:, ~dead_channels]) + traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum( + epsilons[~dead_channels], gain[:, ~dead_channels] + ) return traces, gain diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 4aa014bbeb..bfa4d3d9ae 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -6,7 +6,7 @@ import spikeinterface.core as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se -from spikeinterface.core import generate_recording +from spikeinterface.core import generate_recording, NumpyRecording import importlib.util ON_GITHUB = bool(os.getenv("GITHUB_ACTIONS")) @@ -103,6 +103,27 @@ def test_highpass_spatial_filter_synthetic_data(num_channels, ntr_pad, ntr_tap, assert raw_traces.shape == si_filtered.shape +def test_highpass_spatial_filter_with_dead_channels(): + """Regression test: AGC must handle dead (all-zero) channels without broadcast error. + + PR #4286 changed epsilon from a scalar to a per-channel array, but the agc() + function indexed gain with ~dead_channels without applying the same mask to + epsilons, causing a broadcast error when any channels had zero signal. + """ + num_channels = 32 + rec = generate_recording(num_channels=num_channels, durations=[0.5]) + # Materialize traces and zero out 3 channels to make them "dead" + traces = rec.get_traces().copy() + traces[:, [0, 15, 31]] = 0.0 + rec_with_dead = NumpyRecording( + traces_list=[traces], sampling_frequency=rec.sampling_frequency, channel_ids=rec.channel_ids + ) + rec_with_dead.set_probe(rec.get_probe(), in_place=True) + filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2) + result = filtered.get_traces() + assert result.shape == traces.shape + + @pytest.mark.parametrize("dtype", [np.int16, np.float32, np.float64]) def test_dtype_stability(dtype): """