diff --git a/src/osekit/core_api/spectro_data.py b/src/osekit/core_api/spectro_data.py index be7be5eb..cfc7c106 100644 --- a/src/osekit/core_api/spectro_data.py +++ b/src/osekit/core_api/spectro_data.py @@ -82,6 +82,8 @@ def __init__( self._db_ref = db_ref self.v_lim = v_lim self.colormap = "viridis" if colormap is None else colormap + self.previous_data = None + self.next_data = None @staticmethod def get_default_ax() -> plt.Axes: @@ -249,11 +251,26 @@ def get_value(self) -> np.ndarray: padding="zeros", ) + sx = self._merge_with_previous(sx) + sx = self._remove_overlap_with_next(sx) + if self.sx_dtype is float: sx = abs(sx) ** 2 return sx + def _merge_with_previous(self, data: np.ndarray) -> np.ndarray: + if self.previous_data is None: + return data + olap = SpectroData.get_overlapped_bins(self.previous_data, self) + return np.hstack((olap, data[:, olap.shape[1] :])) + + def _remove_overlap_with_next(self, data: np.ndarray) -> np.ndarray: + if self.next_data is None: + return data + olap = SpectroData.get_overlapped_bins(self, self.next_data) + return data[:, : -olap.shape[1]] + def get_welch( self, nperseg: int | None = None, @@ -567,7 +584,7 @@ def split(self, nb_subdata: int = 2) -> list[SpectroData]: self.audio_data.split_frames(start_frame=a, stop_frame=b) for a, b in itertools.pairwise(split_frames) ] - return [ + sd_split = [ SpectroData.from_audio_data( data=ad, fft=self.fft, @@ -577,6 +594,12 @@ def split(self, nb_subdata: int = 2) -> list[SpectroData]: for ad in ad_split ] + for sd1, sd2 in itertools.pairwise(sd_split): + sd1.next_data = sd2 + sd2.previous_data = sd1 + + return sd_split + def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: if not all( np.array_equal(items[0].file.freq, i.file.freq) @@ -588,23 +611,51 @@ def _get_value_from_items(self, items: list[SpectroItem]) -> np.ndarray: if len({i.file.get_fft().delta_t for i in items if not i.is_empty}) > 1: raise ValueError("Items don't have the same time resolution.") - output 
= items[0].get_value(fft=self.fft, sx_dtype=self.sx_dtype) - for item in items[1:]: - p1_le = self.fft.lower_border_end[1] - self.fft.p_min - output = np.hstack( - ( - output[:, :-p1_le], - ( - output[:, -p1_le:] - + item.get_value(fft=self.fft, sx_dtype=self.sx_dtype)[ - :, - :p1_le, - ] - ), - item.get_value(fft=self.fft, sx_dtype=self.sx_dtype)[:, p1_le:], - ), - ) - return output + return np.hstack( + [item.get_value(fft=self.fft, sx_dtype=self.sx_dtype) for item in items], + ) + + @classmethod + def get_overlapped_bins(cls, sd1: SpectroData, sd2: SpectroData) -> np.ndarray: + """Compute the bins that overflow between the two spectro data. + + The idea is that if there is a SpectroData sd2 that follows sd1, + sd1.get_value() will return the bins up to the first overlapping bin, + and sd2 will return the bins from the first overlapping bin. + + Signal processing guys might want to burn my house to the ground for it, + but it seems to effectively resolve the issue we have with visible junction + between spectrogram zoomed parts. + + Parameters + ---------- + sd1: SpectroData + The spectro data that ends before sd2. + sd2: SpectroData + The spectro data that starts after sd1. + + Returns + ------- + np.ndarray: + The overlapped bins. 
+ If there are p bins, sd1 and sd2 values should be concatenated as: + np.hstack((sd1[:,:-p], result, sd2[:,p:])) + + """ + fft = sd1.fft + sd1_ub = fft.upper_border_begin(sd1.audio_data.shape[0]) + sd1_bin_start = fft.nearest_k_p(k=sd1_ub[0], left=True) + sd2_lb = fft.lower_border_end + sd2_bin_stop = fft.nearest_k_p(k=sd2_lb[0], left=False) + + ad1 = sd1.audio_data.split_frames(start_frame=sd1_bin_start) + ad2 = sd2.audio_data.split_frames(stop_frame=sd2_bin_stop) + + sd_part1 = SpectroData.from_audio_data(ad1, fft=fft).get_value() + sd_part2 = SpectroData.from_audio_data(ad2, fft=fft).get_value() + + p1_le = fft.lower_border_end[1] - fft.p_min + return sd_part1[:, -p1_le:] + sd_part2[:, :p1_le] @classmethod def from_files( diff --git a/src/osekit/core_api/spectro_file.py b/src/osekit/core_api/spectro_file.py index d705e0d1..5654fa90 100644 --- a/src/osekit/core_api/spectro_file.py +++ b/src/osekit/core_api/spectro_file.py @@ -122,9 +122,12 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: start_bin = ( next( - idx - for idx, t in enumerate(time) - if self.begin + Timedelta(seconds=t) > start + ( + idx + for idx, t in enumerate(time) + if self.begin + Timedelta(seconds=t) > start + ), + 1, ) - 1 ) @@ -132,9 +135,12 @@ def read(self, start: Timestamp, stop: Timestamp) -> np.ndarray: stop_bin = ( next( - idx - for idx, t in list(enumerate(time))[::-1] - if self.begin + Timedelta(seconds=t) < stop + ( + idx + for idx, t in list(enumerate(time))[::-1] + if self.begin + Timedelta(seconds=t) < stop + ), + len(time) - 2, ) + 1 ) diff --git a/tests/test_spectro.py b/tests/test_spectro.py index 90e71f8b..062fb6a6 100644 --- a/tests/test_spectro.py +++ b/tests/test_spectro.py @@ -269,14 +269,14 @@ def test_spectro_parameters_in_npz_files( pytest.param( { "duration": 6, - "sample_rate": 1_024, + "sample_rate": 28_000, "nb_files": 1, "date_begin": pd.Timestamp("2024-01-01 12:00:00"), }, None, None, 6, - ShortTimeFFT(hamming(1_024), 100, 1_024), + 
ShortTimeFFT(hamming(1_024), 100, 28_000), id="6_seconds_split_in_6_with_overlap", ), pytest.param( @@ -289,7 +289,7 @@ def test_spectro_parameters_in_npz_files( Instrument(end_to_end_db=150.0), None, 6, - ShortTimeFFT(hamming(1_024), 100, 1_024), + ShortTimeFFT(hamming(1_024), 1_024, 1_024), id="audio_data_with_instrument", ), pytest.param( @@ -302,7 +302,7 @@ def test_spectro_parameters_in_npz_files( None, Normalization.ZSCORE, 6, - ShortTimeFFT(hamming(1_024), 100, 1_024), + ShortTimeFFT(hamming(1_024), 1_024, 1_024), id="audio_data_with_normalization", ), ], @@ -328,7 +328,7 @@ def test_spectrogram_from_npz_files( sd_split = sd.split(nb_chunks) - import soundfile as sf + import soundfile as sf # noqa: PLC0415 for spectro in sd_split: spectro.write(tmp_path / "output") @@ -985,10 +985,14 @@ def test_spectrodata_split( colormap=colormap, ) sd_parts = sd.split(parts) - for sd_part in sd_parts: + for idx, sd_part in enumerate(sd_parts): assert sd_part.fft is sd.fft assert sd_part.v_lim == sd.v_lim assert sd_part.colormap == sd.colormap + if idx > 0: + assert sd_part.previous_data == sd_parts[idx - 1] + if idx < len(sd_parts) - 1: + assert sd_part.next_data == sd_parts[idx + 1] assert sd_parts[0].begin == sd.begin assert sd_parts[-1].end == sd.end