diff --git a/src/post_processing/dataclass/detection_filter.py b/src/post_processing/dataclass/detection_filter.py index 4378192..c7938f4 100644 --- a/src/post_processing/dataclass/detection_filter.py +++ b/src/post_processing/dataclass/detection_filter.py @@ -42,6 +42,7 @@ class DetectionFilter: f_max: float | None = None score: float | None = None box: bool = False + filename_format: str = None @classmethod def from_yaml( @@ -86,7 +87,6 @@ def from_dict( filters = [] for detection_file, filters_dict in parameters.items(): df_preview = read_dataframe(Path(detection_file), nrows=5) - filters_dict["timebin_origin"] = Timedelta( max(df_preview["end_time"]), "s", diff --git a/src/post_processing/utils/filtering_utils.py b/src/post_processing/utils/filtering_utils.py index a656855..036aca5 100644 --- a/src/post_processing/utils/filtering_utils.py +++ b/src/post_processing/utils/filtering_utils.py @@ -196,7 +196,7 @@ def get_dataset(df: DataFrame) -> list[str]: def get_canonical_tz(tz): """Return timezone of object as a pytz timezone.""" if isinstance(tz, datetime.timezone): - if tz == datetime.timezone.utc: + if tz == datetime.UTC: return pytz.utc offset_minutes = int(tz.utcoffset(None).total_seconds() / 60) return pytz.FixedOffset(offset_minutes) @@ -204,13 +204,24 @@ def get_canonical_tz(tz): return pytz.timezone(tz.zone) if hasattr(tz, "key"): return pytz.timezone(tz.key) - else: - msg = f"Unknown timezone: {tz}" - raise TypeError(msg) + msg = f"Unknown timezone: {tz}" + raise TypeError(msg) def get_timezone(df: DataFrame): - """Return timezone(s) from DataFrame.""" + """Return timezone(s) from APLOSE DataFrame. + + Parameters + ---------- + df: DataFrame + APLOSE result Dataframe + + Returns + ------- + tzoffset: list[tzoffset] + list of timezones + + """ timezones = {get_canonical_tz(ts.tzinfo) for ts in df["start_datetime"]} if len(timezones) == 1: @@ -218,10 +229,29 @@ def get_timezone(df: DataFrame): return list(timezones) +def check_timestamp(df: DataFrame, timestamp_audio: list[Timestamp]) -> None: + """Check if provided timestamp_audio list is correctly formated. + + Parameters + ---------- + df: DataFrame APLOSE results Dataframe. + timestamp_audio: A list of timestamps. Each timestamp is the start datetime of the + corresponding audio file for each detection in df. + + """ + if timestamp_audio is None: + msg = "`timestamp_wav` is empty" + raise ValueError(msg) + if len(timestamp_audio) != len(df): + msg = "`timestamp_wav` is not the same length as `df`" + raise ValueError(msg) + + def reshape_timebin( df: DataFrame, + *, timebin_new: Timedelta | None, - timestamp: list[Timestamp] | None = None, + timestamp_audio: list[Timestamp] | None = None, ) -> DataFrame: """Reshape an APLOSE result DataFrame according to a new time bin. @@ -231,8 +261,9 @@ def reshape_timebin( An APLOSE result DataFrame. timebin_new: Timedelta The size of the new time bin. - timestamp: list[Timestamp] - A list of Timestamp objects. + timestamp_audio: list[Timestamp] + A list of Timestamp objects corresponding to the shape + in which the data should be reshaped. Returns ------- @@ -247,14 +278,20 @@ def reshape_timebin( if not timebin_new: return df + check_timestamp(df, timestamp_audio) + annotators = get_annotators(df) labels = get_labels(df) max_freq = get_max_freq(df) dataset = get_dataset(df) if isinstance(get_timezone(df), list): - df["start_datetime"] = [to_datetime(elem, utc=True) for elem in df["start_datetime"]] - df["end_datetime"] = [to_datetime(elem, utc=True) for elem in df["end_datetime"]] + df["start_datetime"] = [to_datetime(elem, utc=True) + for elem in df["start_datetime"] + ] + df["end_datetime"] = [to_datetime(elem, utc=True) + for elem in df["end_datetime"] + ] results = [] for ant in annotators: @@ -264,13 +301,13 @@ def reshape_timebin( if df_1annot_1label.empty: continue - if timestamp is not None: + if timestamp_audio is not None: # I do not remember if this is a regular case or not # might need to be deleted - origin_timebin = timestamp[1] - timestamp[0] - step = int(timebin_new / origin_timebin) - time_vector = timestamp[0::step] - else: + #origin_timebin = timestamp_audio[1] - timestamp_audio[0] + #step = int(timebin_new / origin_timebin) + #time_vector = timestamp_audio[0::step] + #else: t1 = min(df_1annot_1label["start_datetime"]).floor(timebin_new) t2 = max(df_1annot_1label["end_datetime"]).ceil(timebin_new) time_vector = date_range(start=t1, end=t2, freq=timebin_new) @@ -280,14 +317,19 @@ def reshape_timebin( filenames = df_1annot_1label["filename"].to_list() # filename_vector - filename_vector = [ - filenames[ - bisect.bisect_left(ts_detect_beg, ts) - (ts not in ts_detect_beg) - ] - if bisect.bisect_left(ts_detect_beg, ts) > 0 - else filenames[0] - for ts in time_vector - ] + filename_vector = [] + for ts in time_vector: + if (bisect.bisect_left(ts_detect_beg, ts) > 0 and + bisect.bisect_left(ts_detect_beg, ts) != len(ts_detect_beg)): + idx = bisect.bisect_left(ts_detect_beg, ts) + filename_vector.append( + filenames[idx] if timestamp_audio[idx] <= ts else + filenames[idx - 1], + ) + elif bisect.bisect_left(ts_detect_beg, ts) == len(ts_detect_beg): + filename_vector.append(filenames[-1]) + else: + filename_vector.append(filenames[0]) # detection vector detect_vec = [0] * len(time_vector) @@ -327,8 +369,39 @@ def reshape_timebin( ), ) - return concat(results).sort_values(by=["start_datetime", "end_datetime", "annotator", "annotation"]).reset_index(drop=True) + return (concat(results). + sort_values(by=["start_datetime", "end_datetime", + "annotator", "annotation"]).reset_index(drop=True) + ) + +def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]: + """Get start timestamps of the wav files of each detection contained in df. + + Parameters. + ---------- + df: DataFrame + An APLOSE result DataFrame. + date_parser: str + date parser of the wav file + + Returns + ------- + List of Timestamps corresponding to the wav files' start timestamps + of each detection contained in df. + + """ + tz = get_timezone(df) + try: + return [ + to_datetime( + ts, + format=date_parser, + ).tz_localize(tz) for ts in df["filename"] + ] + except ValueError: + msg = """Could not parse timestamps from `df["filename"]`.""" + raise ValueError(msg) from None def ensure_in_list(value: str, candidates: list[str], label: str) -> None: """Check for non-valid elements of a list.""" @@ -366,10 +439,14 @@ def load_detections(filters: DetectionFilter) -> DataFrame: df = filter_by_label(df, label=filters.annotation) df = filter_by_freq(df, filters.f_min, filters.f_max) df = filter_by_score(df, filters.score) - df = reshape_timebin(df, filters.timebin_new) + filename_ts = get_filename_timestamps(df, filters.filename_format) + df = reshape_timebin(df, + timebin_new=filters.timebin_new, + timestamp_audio=filename_ts + ) annotators = get_annotators(df) - if len(annotators) > 1 and filters.user_sel in ["union", "intersection"]: + if len(annotators) > 1 and filters.user_sel in {"union", "intersection"}: df = intersection_or_union(df, user_sel=filters.user_sel) return df.sort_values(by=["start_datetime", "end_datetime"]).reset_index(drop=True) @@ -385,7 +462,7 @@ def intersection_or_union(df: DataFrame, user_sel: str) -> DataFrame: if user_sel == "all": return df - if user_sel not in ("intersection", "union"): + if user_sel not in {"intersection", "union"}: msg = "'user_sel' must be either 'intersection' or 'union'" raise ValueError(msg) diff --git a/tests/conftest.py b/tests/conftest.py index 158c269..e03bf43 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -164,10 +164,11 @@ def sample_csv_timestamp(tmp_path: Path, sample_status: DataFrame) -> Path: @pytest.fixture -def sample_yaml(tmp_path: Path, - sample_csv_result: Path, - sample_csv_timestamp: Path, - ) -> Path: +def sample_yaml( + tmp_path: Path, + sample_csv_result: Path, + sample_csv_timestamp: Path, +) -> Path: yaml_content = { f"{sample_csv_result}": { "timebin_new": None, @@ -177,6 +178,7 @@ def sample_yaml(tmp_path: Path, "annotation": "lbl1", "box": True, "timestamp_file": f"{sample_csv_timestamp}", + "filename_format": "%Y_%m_%d_%H_%M_%S", "user_sel": "all", "f_min": None, "f_max": None, diff --git a/tests/test_DetectionFilters.py b/tests/test_DetectionFilters.py index 6b7dd2f..90d80ab 100644 --- a/tests/test_DetectionFilters.py +++ b/tests/test_DetectionFilters.py @@ -18,6 +18,7 @@ def test_from_yaml(sample_yaml: Path, "annotation": "lbl1", "box": True, "timestamp_file": f"{sample_csv_timestamp}", + "filename_format": "%Y_%m_%d_%H_%M_%S", "user_sel": "all", "f_min": None, "f_max": None, diff --git a/tests/test_filtering_utils.py b/tests/test_filtering_utils.py index 12eb988..68b8d20 100644 --- a/tests/test_filtering_utils.py +++ b/tests/test_filtering_utils.py @@ -1,10 +1,12 @@ +from __future__ import annotations + import csv from pathlib import Path from zoneinfo import ZoneInfo import pytest import pytz -from pandas import DataFrame, Timedelta, Timestamp, date_range, concat, to_datetime +from pandas import DataFrame, Timedelta, Timestamp, concat, to_datetime from post_processing.utils.filtering_utils import ( filter_by_annotator, @@ -296,7 +298,7 @@ def test_get_timezone_single(sample_df: DataFrame) -> None: def test_get_timezone_several(sample_df: DataFrame) -> None: new_row = { "dataset": "dataset", - "filename": "filename", + "filename": "2025_01_26_06_20_00", "start_time": 0, "end_time": 2, "start_frequency": 100, @@ -382,7 +384,7 @@ def test_no_timebin_returns_original(sample_df: DataFrame) -> None: def test_no_timebin_several_tz(sample_df: DataFrame) -> None: new_row = { "dataset": "dataset", - "filename": "filename", + "filename": "2025_01_26_06_20_00", "start_time": 0, "end_time": 2, "start_frequency": 100, @@ -398,13 +400,23 @@ def test_no_timebin_several_tz(sample_df: DataFrame) -> None: [sample_df, DataFrame([new_row])], ignore_index=False ) - - df_out = reshape_timebin(sample_df, timebin_new=None) + timestamp_wav = to_datetime(sample_df["filename"], + format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(pytz.UTC) + df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=None) assert df_out.equals(sample_df) def test_no_timebin_original_timebin(sample_df: DataFrame) -> None: - df_out = reshape_timebin(sample_df, timebin_new=Timedelta("1min")) + tz = get_timezone(sample_df) + timestamp_wav = to_datetime( + sample_df["filename"], + format="%Y_%m_%d_%H_%M_%S" + ).dt.tz_localize(tz) + df_out = reshape_timebin( + sample_df, + timestamp_audio=timestamp_wav, + timebin_new=Timedelta("1min"), + ) expected = DataFrame( { "dataset": ["sample_dataset"] * 18, @@ -486,7 +498,16 @@ def test_no_timebin_original_timebin(sample_df: DataFrame) -> None: def test_simple_reshape_hourly(sample_df: DataFrame) -> None: - df_out = reshape_timebin(sample_df, timebin_new=Timedelta(hours=1)) + tz = get_timezone(sample_df) + timestamp_wav = to_datetime( + sample_df["filename"], + format="%Y_%m_%d_%H_%M_%S" + ).dt.tz_localize(tz) + df_out = reshape_timebin( + sample_df, + timestamp_audio=timestamp_wav, + timebin_new=Timedelta(hours=1), + ) assert not df_out.empty assert all(df_out["end_time"] == 3600.0) assert df_out["end_frequency"].max() == sample_df["end_frequency"].max() @@ -495,7 +516,12 @@ def test_simple_reshape_hourly(sample_df: DataFrame) -> None: def test_reshape_daily_multiple_bins(sample_df: DataFrame) -> None: - df_out = reshape_timebin(sample_df, timebin_new=Timedelta(days=1)) + tz = get_timezone(sample_df) + timestamp_wav = to_datetime( + sample_df["filename"], + format="%Y_%m_%d_%H_%M_%S" + ).dt.tz_localize(tz) + df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=Timedelta(days=1)) assert not df_out.empty assert all(df_out["end_time"] == 86400.0) assert df_out["start_datetime"].min() >= sample_df["start_datetime"].min().floor("D") @@ -503,14 +529,14 @@ def test_reshape_daily_multiple_bins(sample_df: DataFrame) -> None: def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None: - t0 = sample_df["start_datetime"].min().floor("30min") - t1 = sample_df["end_datetime"].max().ceil("30min") - ts_vec = list(date_range(t0, t1, freq="30min")) + tz = get_timezone(sample_df) + timestamp_wav = to_datetime(sample_df["filename"], + format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz) df_out = reshape_timebin( sample_df, - timebin_new=Timedelta(hours=1), - timestamp=ts_vec, + timestamp_audio=timestamp_wav, + timebin_new=Timedelta(hours=1) ) assert not df_out.empty @@ -519,8 +545,11 @@ def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None: def test_empty_result_when_no_matching(sample_df: DataFrame) -> None: + tz = get_timezone(sample_df) + timestamp_wav = to_datetime(sample_df["filename"], + format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(tz) with pytest.raises(ValueError, match="DataFrame is empty"): - reshape_timebin(DataFrame(), Timedelta(hours=1)) + reshape_timebin(DataFrame(), timestamp_audio=timestamp_wav, timebin_new=Timedelta(hours=1)) # %% ensure_no_invalid