Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def make_dataset(job_kwargs={}):
contact_shape_params={"radius": 6},
),
generate_sorting_kwargs=dict(firing_rates=6.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
noise_kwargs=dict(noise_levels=5.0),
seed=2205,
)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2286,7 +2286,7 @@ def generate_ground_truth_recording(
upsample_factor=None,
upsample_vector=None,
generate_sorting_kwargs=dict(firing_rates=15, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
noise_kwargs=dict(noise_levels=5.0),
generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=5.0, maximum_z=50.0, minimum_distance=20),
generate_templates_kwargs=None,
dtype="float32",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def get_sorting_analyzer(cache_folder, format="memory", sparse=True):
alpha=(200.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=5.0),
seed=2406,
)
if format == "memory":
Expand Down
2 changes: 0 additions & 2 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def test_noise_generator_several_noise_levels():
dtype="float32",
seed=32,
noise_levels=1,
strategy="on_the_fly",
noise_block_size=20000,
)
assert np.all(np.abs(get_noise_levels(rec1) - 1) < 0.1)
Expand All @@ -232,7 +231,6 @@ def test_noise_generator_several_noise_levels():
dtype="float32",
seed=32,
noise_levels=[0, 1, 2, 3],
strategy="on_the_fly",
noise_block_size=20000,
)
assert np.all(np.abs(get_noise_levels(rec2) - np.arange(4)) < 0.1)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_random_spikes_selection():
num_channels=10,
num_units=5,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=5.0),
seed=2205,
)
max_spikes_per_unit = 12
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_dataset():
num_channels=10,
num_units=5,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=5.0),
seed=2205,
)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_sparsity.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def get_dataset():
num_channels=10,
num_units=5,
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=1.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=1.0),
seed=2205,
)
recording.set_property("group", ["a"] * 5 + ["b"] * 5)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_sorting_analyzer():
sampling_frequency=10_000.0,
num_channels=4,
num_units=10,
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=5.0),
seed=2205,
)
recording.annotate(is_filtered=True)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/tests/test_waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_dataset():
num_channels=4,
num_units=7,
generate_sorting_kwargs=dict(firing_rates=5.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=1.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=1.0),
seed=2205,
)
return recording, sorting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_dataset():
alpha=(100.0, 500.0),
)
),
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
noise_kwargs=dict(noise_levels=5.0),
seed=2406,
)
return recording, sorting
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/curation/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def make_sorting_analyzer(sparse=True, num_units=5, durations=[300.0]):
num_channels=4,
num_units=num_units,
generate_sorting_kwargs=dict(firing_rates=20.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
noise_kwargs=dict(noise_levels=5.0),
seed=2205,
)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/exporters/tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def make_sorting_analyzer(sparse=True, with_group=False):
contact_shape_params={"radius": 6},
),
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_levels=5.0, strategy="on_the_fly"),
noise_kwargs=dict(noise_levels=5.0),
seed=2205,
)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/toy_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def toy_example(
ms_after=ms_after,
dtype="float32",
seed=seed,
noise_kwargs=dict(noise_levels=10.0, strategy="on_the_fly"),
noise_kwargs=dict(noise_levels=10.0),
)

return recording, sorting
119 changes: 76 additions & 43 deletions src/spikeinterface/generation/noise_tools.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
from typing import Literal

from spikeinterface.core import BaseRecording, BaseRecordingSegment
from spikeinterface.core.generate import _ensure_seed
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessorSegment
from spikeinterface.core.recording_tools import get_chunk_with_margin


class NoiseGeneratorRecording(BaseRecording):
Expand All @@ -29,16 +30,18 @@ class NoiseGeneratorRecording(BaseRecording):
Std of the white noise (if an array, defined by per channels)
cov_matrix : np.ndarray | None, default: None
The covariance matrix of the noise
spectral_density : np.ndarray | None, default: None
The spectral density of the noise, as you could estimate from an array of snippets with shape
`(n_snippets, spectral_snippet_length)` by the following method (Welch's method):

```python
periodogram = rfft(snippets, n=next_fast_len(snippets.shape[1]), norm="ortho")
spectral_density = np.sqrt((periodogram * periodogram.conj()).mean(axis=0))
```
dtype : np.dtype | str | None, default: "float32"
The dtype of the recording. Note that only np.float32 and np.float64 are supported.
seed : int | None, default: None
The seed for np.random.default_rng.
strategy : "tile_pregenerated" | "on_the_fly", default: "tile_pregenerated"
The strategy of generating noise chunk:
* "tile_pregenerated": pregenerate a noise chunk of noise_block_size sample and repeat it
very fast and cusume only one noise block.
* "on_the_fly": generate on the fly a new noise block by combining seed + noise block index
no memory preallocation but a bit more computaion (random)
noise_block_size : int, default: 30000
Size in sample of noise block.

Expand All @@ -55,17 +58,16 @@ def __init__(
durations: list[float],
noise_levels: float | np.ndarray = 1.0,
cov_matrix: np.ndarray | None = None,
spectral_density: np.ndarray | None = None,
dtype: np.dtype | str | None = "float32",
seed: int | None = None,
strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated",
noise_block_size: int = 30000,
):

channel_ids = [str(index) for index in np.arange(num_channels)]
dtype = np.dtype(dtype).name # Cast to string for serialization
if dtype not in ("float32", "float64"):
raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}")
assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'"

if np.isscalar(noise_levels):
noise_levels = np.ones((1, num_channels)) * noise_levels
Expand Down Expand Up @@ -103,8 +105,9 @@ def __init__(
cov_matrix,
dtype,
segments_seeds[i],
strategy,
)
if spectral_density is not None:
rec_segment = AddTemporalCorrelationsSegment(rec_segment, spectral_density)
self.add_recording_segment(rec_segment)

self._kwargs = {
Expand All @@ -115,7 +118,6 @@ def __init__(
"cov_matrix": cov_matrix,
"dtype": dtype,
"seed": seed,
"strategy": strategy,
"noise_block_size": noise_block_size,
}

Expand All @@ -131,7 +133,6 @@ def __init__(
cov_matrix,
dtype,
seed,
strategy,
):
assert seed is not None

Expand All @@ -144,23 +145,6 @@ def __init__(
self.cov_matrix = cov_matrix
self.dtype = dtype
self.seed = seed
self.strategy = strategy

if self.strategy == "tile_pregenerated":
rng = np.random.default_rng(seed=self.seed)

if self.cov_matrix is None:
self.noise_block = (
rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype)
* noise_levels
)
else:
self.noise_block = rng.multivariate_normal(
np.zeros(self.num_channels), self.cov_matrix, size=self.noise_block_size
)

elif self.strategy == "on_the_fly":
pass

def get_num_samples(self) -> int:
return self.num_samples
Expand All @@ -169,7 +153,7 @@ def get_traces(
self,
start_frame: int | None = None,
end_frame: int | None = None,
channel_indices: list | None = None,
channel_indices: list | np.ndarray | tuple | None = None,
) -> np.ndarray:

if start_frame is None:
Expand All @@ -188,18 +172,15 @@ def get_traces(

pos = 0
for block_index in range(first_block_index, last_block_index + 1):
if self.strategy == "tile_pregenerated":
noise_block = self.noise_block
elif self.strategy == "on_the_fly":
rng = np.random.default_rng(seed=(self.seed, block_index))
if self.cov_matrix is None:
noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype)
else:
noise_block = rng.multivariate_normal(
np.zeros(self.num_channels), self.cov_matrix, size=self.noise_block_size
)
rng = np.random.default_rng(seed=(self.seed, block_index))
if self.cov_matrix is None:
noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype)
else:
noise_block = rng.multivariate_normal(
np.zeros(self.num_channels), self.cov_matrix, size=self.noise_block_size
)

noise_block *= self.noise_levels
noise_block *= self.noise_levels

if block_index == first_block_index:
if first_block_index != last_block_index:
Expand All @@ -222,13 +203,57 @@ def get_traces(
return traces


class AddTemporalCorrelationsSegment(BasePreprocessorSegment):
def __init__(self, parent_recording_segment, spectral_density: np.ndarray):
super().__init__(parent_recording_segment)
assert spectral_density.ndim == 1
self.spectral_density = spectral_density
self.margin = spectral_density.shape[0] - 1
self.block_len = 2 * spectral_density.shape[0] - 1
self.kernel = np.fft.fftshift(np.fft.irfft(spectral_density, n=self.block_len))

def get_traces(
self,
start_frame: int | None = None,
end_frame: int | None = None,
channel_indices: list | np.ndarray | tuple | None = None,
):
from scipy.signal import convolve

if start_frame is None:
start_frame = 0
if end_frame is None:
end_frame = self.get_num_samples()

traces, *_ = get_chunk_with_margin(
self.parent_recording_segment,
start_frame=start_frame,
end_frame=end_frame,
channel_indices=channel_indices,
margin=self.margin,
add_reflect_padding=True,
)
# need to use "direct", or else output differs numerically when start_frame, end_frame change
# that's because the FFT method would FFT the traces, and there would be slight numerical differences
traces = convolve(traces.T, self.kernel[None], mode="valid", method="direct").T
assert traces.shape[0] == end_frame - start_frame
return traces


noise_generator_recording = define_function_from_class(
source_class=NoiseGeneratorRecording, name="noise_generator_recording"
)


def generate_noise(
probe, sampling_frequency, durations, dtype="float32", noise_levels=15.0, spatial_decay=None, seed=None
probe,
sampling_frequency,
durations,
dtype="float32",
noise_levels=15.0,
spatial_decay=None,
spectral_density=None,
seed=None,
):
"""
Generate a noise recording.
Expand All @@ -249,6 +274,14 @@ def generate_noise(
If tuple, then this represent the range.
spatial_decay : float | None, default: None
If not None, the spatial decay of the noise used to generate the noise covariance matrix.
spectral_density : np.ndarray | None, default: None
The spectral density of the noise, as you could estimate from an array of snippets with shape
`(n_snippets, spectral_snippet_length)` by the following method (Welch's method):

```python
periodogram = rfft(snippets, n=next_fast_len(snippets.shape[1]), norm="ortho")
spectral_density = np.sqrt((periodogram * periodogram.conj()).mean(axis=0))
```
seed : int | None, default: None
The seed for random generator.

Expand Down Expand Up @@ -284,9 +317,9 @@ def generate_noise(
sampling_frequency=sampling_frequency,
durations=durations,
dtype=dtype,
strategy="on_the_fly",
noise_levels=noise_levels,
cov_matrix=cov_matrix,
spectral_density=spectral_density,
seed=seed,
)

Expand Down
Loading
Loading