diff --git a/src/post_processing/dataclass/data_aplose.py b/src/post_processing/dataclass/data_aplose.py
index 98e6d9c..3c8d1e4 100644
--- a/src/post_processing/dataclass/data_aplose.py
+++ b/src/post_processing/dataclass/data_aplose.py
@@ -393,7 +393,6 @@ def plot(
     color = kwargs.get("color")
     season = kwargs.get("season")
     effort = kwargs.get("effort")
-
     if not bin_size:
         msg = "'bin_size' missing for histogram plot."
         raise ValueError(msg)
diff --git a/src/post_processing/dataclass/detection_filter.py b/src/post_processing/dataclass/detection_filter.py
index d636c4c..b28c023 100644
--- a/src/post_processing/dataclass/detection_filter.py
+++ b/src/post_processing/dataclass/detection_filter.py
@@ -7,7 +7,7 @@
 
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass, fields
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 
@@ -44,6 +44,12 @@ class DetectionFilter:
     box: bool = False
     filename_format: str = None
 
+    def __getitem__(self, key: str):
+        """Return the value of the given dataclass field."""
+        if key in {f.name for f in fields(self)}:
+            return getattr(self, key)
+        raise KeyError(key)
+
     @classmethod
     def from_yaml(
         cls,
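Review note: the new `__getitem__` makes a `DetectionFilter` readable like a mapping while still failing loudly on unknown names. A minimal self-contained sketch of the same pattern (a toy dataclass, not the real class):

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class Demo:
    """Toy stand-in mirroring DetectionFilter's __getitem__ pattern."""

    box: bool = False
    filename_format: str | None = None

    def __getitem__(self, key: str):
        # Only expose declared dataclass fields, never arbitrary attributes.
        if key in {f.name for f in fields(self)}:
            return getattr(self, key)
        raise KeyError(key)


demo = Demo()
print(demo["box"])       # False, same as demo.box
try:
    demo["no_such_field"]
except KeyError as err:
    print("KeyError:", err)  # unknown keys raise KeyError, not AttributeError
```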
diff --git a/src/post_processing/dataclass/recording_period.py b/src/post_processing/dataclass/recording_period.py
index 4c09722..637733e 100644
--- a/src/post_processing/dataclass/recording_period.py
+++ b/src/post_processing/dataclass/recording_period.py
@@ -8,19 +8,15 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
-from osekit.config import TIMESTAMP_FORMATS_EXPORTED_FILES
-from osekit.utils.timestamp_utils import strptime_from_text
 from pandas import (
     Series,
     Timedelta,
-    cut,
+    date_range,
+    interval_range,
     read_csv,
+    to_datetime,
 )
 
-from post_processing.utils.core_utils import (
-    get_time_range_and_bin_size,
-    localize_timestamps,
-)
 from post_processing.utils.filtering_utils import (
     find_delimiter,
 )
@@ -33,7 +29,7 @@
 
 @dataclass(frozen=True)
 class RecordingPeriod:
-    """A class to handle recording periods."""
+    """Represent recording effort over time, aggregated into bins."""
 
     counts: Series
     timebin_origin: Timedelta
@@ -42,33 +38,117 @@ class RecordingPeriod:
     def from_path(
         cls,
         config: DetectionFilter,
-        date_format: str = TIMESTAMP_FORMATS_EXPORTED_FILES,
         *,
         bin_size: Timedelta | BaseOffset,
     ) -> RecordingPeriod:
-        """Return a list of Timestamps corresponding to recording periods."""
+        """Build recording coverage from a CSV of start/end datetimes.
+
+        This method reads a CSV with columns:
+        - "start_recording"
+        - "end_recording"
+        - "start_deployment"
+        - "end_deployment"
+
+        It computes the **effective recording interval** as the intersection between
+        recording and deployment periods, builds a fine-grained timeline at
+        `timebin_origin` resolution, and aggregates effort into `bin_size` bins.
+
+        Parameters
+        ----------
+        config : DetectionFilter
+            Configuration object containing at least:
+            - `timestamp_file`: path to the CSV
+            - `timebin_origin`: Timedelta resolution of detections
+        bin_size : Timedelta or BaseOffset
+            Size of the aggregation bin (e.g., Timedelta("1h") or "1D").
+
+        Returns
+        -------
+        RecordingPeriod
+            Object containing `counts` (Series indexed by a left-closed
+            IntervalIndex) and `timebin_origin`.
+
+        """
+        # Read the CSV as-is; datetime columns are validated and parsed below
         timestamp_file = config.timestamp_file
         delim = find_delimiter(timestamp_file)
-        timestamp_df = read_csv(timestamp_file, delimiter=delim)
-
-        if "timestamp" in timestamp_df.columns:
-            msg = "Parsing 'timestamp' column not implemented yet."
-            raise NotImplementedError(msg)
-
-        if "filename" in timestamp_df.columns:
-            timestamps = [
-                strptime_from_text(ts, date_format)
-                for ts in timestamp_df["filename"]
-            ]
-            timestamps = localize_timestamps(timestamps, config.timezone)
-            time_vector, bin_size = get_time_range_and_bin_size(timestamps, bin_size)
-
-            binned = cut(timestamps, time_vector)
-            max_annot = bin_size / config.timebin_origin
-
-            return cls(counts=binned.value_counts().sort_index().clip(upper=max_annot),
-                timebin_origin=config.timebin_origin,
-            )
-
-        msg = "Could not parse timestamps."
-        raise ValueError(msg)
+        df = read_csv(timestamp_file, delimiter=delim)
+
+        if df.empty:
+            msg = "CSV is empty."
+            raise ValueError(msg)
+
+        # Ensure all required columns are present
+        required_columns = {
+            "start_recording",
+            "end_recording",
+            "start_deployment",
+            "end_deployment",
+        }
+
+        missing = required_columns - set(df.columns)
+
+        if missing:
+            msg = f"CSV is missing required columns: {', '.join(sorted(missing))}"
+            raise ValueError(msg)
+
+        # Parse datetime strings and normalize to UTC, then drop tz info (naive)
+        for col in [
+            "start_recording",
+            "end_recording",
+            "start_deployment",
+            "end_deployment",
+        ]:
+            df[col] = to_datetime(df[col], utc=True).dt.tz_convert(None)
+
+        # Compute effective recording intervals (intersection)
+        df["effective_start_recording"] = df[
+            ["start_recording", "start_deployment"]
+        ].max(axis=1)
+
+        df["effective_end_recording"] = df[
+            ["end_recording", "end_deployment"]
+        ].min(axis=1)
+
+        # Remove rows with no actual recording interval
+        df = df.loc[
+            df["effective_start_recording"] < df["effective_end_recording"]
+        ].copy()
+
+        if df.empty:
+            msg = "No valid recording intervals after deployment intersection."
+            raise ValueError(msg)
+
+        # Build fine-grained timeline at `timebin_origin` resolution
+        origin = config.timebin_origin
+        time_index = date_range(
+            start=df["effective_start_recording"].min(),
+            end=df["effective_end_recording"].max(),
+            freq=origin,
+        )
+
+        # Initialize effort vector (0 = no recording, 1 = recording)
+        effort = Series(0, index=time_index)
+
+        # Vectorized interval coverage: compare each timestamp to all intervals
+        t_vals = time_index.to_numpy()[:, None]
+        start_vals = df["effective_start_recording"].to_numpy()
+        end_vals = df["effective_end_recording"].to_numpy()
+
+        # Boolean matrix: True if the timestamp falls within any recording interval
+        covered = (t_vals >= start_vals) & (t_vals < end_vals)
+        effort[:] = covered.any(axis=1).astype(int)
+
+        # Aggregate effort into user-defined bins; label each bin by its left
+        # edge so the IntervalIndex built below lines up with the counts
+        counts = effort.resample(bin_size, closed="left", label="left").sum()
+
+        # Replace index with IntervalIndex for downstream compatibility
+        counts.index = interval_range(
+            start=counts.index[0],
+            periods=len(counts),
+            freq=bin_size,
+            closed="left",
+        )
+
+        return cls(counts=counts, timebin_origin=origin)
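For reviewers, end-to-end usage of the new loader, mirroring the test fixture added further down. `PlanningConfig` is a hypothetical stand-in for a full `DetectionFilter`; `from_path` only reads these two attributes:

```python
from pathlib import Path

from pandas import Timedelta
from pandas.tseries import frequencies

from post_processing.dataclass.recording_period import RecordingPeriod

csv_path = Path("planning.csv")
csv_path.write_text(
    "start_recording,end_recording,start_deployment,end_deployment\n"
    "2024-01-01 00:00:00+0000,2024-04-09 02:00:00+0000,"
    "2024-01-02 00:00:00+0000,2024-04-30 02:00:00+0000\n"
)


class PlanningConfig:
    """Hypothetical stand-in; from_path only touches these attributes."""

    timestamp_file = csv_path
    timebin_origin = Timedelta("1min")


period = RecordingPeriod.from_path(
    config=PlanningConfig(),
    bin_size=frequencies.to_offset("1W"),
)
print(period.counts.head())  # minutes of effort per weekly interval
```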
diff --git a/src/post_processing/utils/core_utils.py b/src/post_processing/utils/core_utils.py
index 5a831e1..9457ef2 100644
--- a/src/post_processing/utils/core_utils.py
+++ b/src/post_processing/utils/core_utils.py
@@ -11,7 +11,7 @@
 from astral.sun import sunrise, sunset
 from matplotlib import pyplot as plt
 from osekit.config import TIMESTAMP_FORMAT_AUDIO_FILE
-from osekit.utils.timestamp_utils import strptime_from_text, strftime_osmose_format
+from osekit.utils.timestamp_utils import strftime_osmose_format, strptime_from_text
 from pandas import (
     DataFrame,
     DatetimeIndex,
@@ -255,7 +255,6 @@ def add_weak_detection(
             new_line.append(np.nan)
             df.loc[df.index.max() + 1] = new_line
 
-
     return df.sort_values(by=["start_datetime", "annotator"]).reset_index(drop=True)
 
 
@@ -509,11 +508,10 @@ def get_time_range_and_bin_size(
 
     if isinstance(bin_size, Timedelta):
         return timestamp_range, bin_size
-    elif isinstance(bin_size, BaseOffset):
+    if isinstance(bin_size, BaseOffset):
        return timestamp_range, timestamp_range[1] - timestamp_range[0]
-    else:
-        msg = "bin_size must be a Timedelta or BaseOffset."
-        raise TypeError(msg)
+    msg = "bin_size must be a Timedelta or BaseOffset."
+    raise TypeError(msg)
 
 
 def round_begin_end_timestamps(
diff --git a/src/post_processing/utils/filtering_utils.py b/src/post_processing/utils/filtering_utils.py
index c391ff6..faf2fd7 100644
--- a/src/post_processing/utils/filtering_utils.py
+++ b/src/post_processing/utils/filtering_utils.py
@@ -8,6 +8,7 @@
 from typing import TYPE_CHECKING
 
 import pytz
+from osekit.utils.timestamp_utils import strptime_from_text
 from pandas import (
     DataFrame,
     Timedelta,
@@ -509,8 +510,8 @@
     timebin_new: Timedelta
         The size of the new time bin.
     timestamp_audio: list[Timestamp]
-        A list of Timestamp objects corresponding to the shape
-        in which the data should be reshaped.
+        A list of Timestamp objects marking the start of each wav file
+        associated with a detection.
 
     Returns
     -------
@@ -570,16 +571,19 @@ def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]:
 
     """
     tz = get_timezone(df)
-    try:
-        return [
-            to_datetime(
-                ts,
-                format=date_parser,
-            ).tz_localize(tz) for ts in df["filename"]
-        ]
-    except ValueError:
-        msg = """Could not parse timestamps from `df["filename"]`."""
-        raise ValueError(msg) from None
+    timestamps = [
+        strptime_from_text(
+            ts,
+            datetime_template=date_parser,
+        ) for ts in df["filename"]
+    ]
+
+    # Localize only when every parsed timestamp is naive; otherwise keep
+    # the timezone information parsed from the filenames.
+    if all(t.tz is None for t in timestamps):
+        timestamps = [t.tz_localize(tz) for t in timestamps]
+
+    return timestamps
 
 
 def ensure_in_list(value: str, candidates: list[str], label: str) -> None:
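Swapping `to_datetime` for osekit's `strptime_from_text` lets timestamps embedded in longer filenames parse too. A quick sketch with the filename pattern used throughout the tests (exact return formatting depends on osekit):

```python
from osekit.utils.timestamp_utils import strptime_from_text

# Same template the tests pass as date_parser, e.g. "2025_01_25_06_20_00".
ts = strptime_from_text(
    "2025_01_25_06_20_00",
    datetime_template="%Y_%m_%d_%H_%M_%S",
)
print(ts, ts.tz)  # naive timestamp; get_filename_timestamps localizes it
```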
diff --git a/src/post_processing/utils/plot_utils.py b/src/post_processing/utils/plot_utils.py
index f21e343..f5ad0df 100644
--- a/src/post_processing/utils/plot_utils.py
+++ b/src/post_processing/utils/plot_utils.py
@@ -11,9 +11,18 @@
 import numpy as np
 from matplotlib import dates as mdates
 from matplotlib.dates import num2date
+from matplotlib.patches import Patch
 from matplotlib.ticker import PercentFormatter
 from numpy import ceil, histogram, polyfit
-from pandas import DataFrame, DatetimeIndex, Index, Timedelta, Timestamp, date_range
+from pandas import (
+    DataFrame,
+    DatetimeIndex,
+    Index,
+    Series,
+    Timedelta,
+    Timestamp,
+    date_range,
+)
 from pandas.tseries import frequencies
 from scipy.stats import pearsonr
 from seaborn import scatterplot
@@ -28,11 +37,10 @@
     timedelta_to_str,
 )
 from post_processing.utils.filtering_utils import (
+    filter_by_annotator,
     get_max_time,
     get_timezone,
-    filter_by_annotator,
 )
-from post_processing.utils.metrics_utils import normalize_counts_by_effort
 
 if TYPE_CHECKING:
     from datetime import tzinfo
@@ -107,9 +115,6 @@
     else:
         legend_labels = None
 
-    if effort:
-        normalize_counts_by_effort(df, effort, time_bin)
-
     n_groups = len(labels) if legend_labels else 1
     bar_width = bin_size / n_groups
     bin_starts = mdates.date2num(df.index)
@@ -128,6 +133,8 @@
             bar_kwargs["label"] = legend_labels[i]
 
         ax.bar(bin_starts + offset, df.iloc[:, i], **bar_kwargs)
+        if kwargs.get("show_recording_OFF"):
+            ax.set_facecolor("lightgrey")
 
     if len(df.columns) > 1 and legend:
         ax.legend(labels=legend_labels, bbox_to_anchor=(1.01, 1), loc="upper left")
@@ -138,7 +145,7 @@
         f" - bin size: {bin_size_str})"
     )
     ax.set_ylabel(y_label)
-    set_y_axis_to_percentage(ax) if effort else set_dynamic_ylim(ax, df)
+    set_dynamic_ylim(ax, df)
     set_plot_title(ax, annotators, labels)
 
     ax.set_xlim(begin, end)
@@ -659,21 +666,60 @@
 
     """
     width_days = bar_width.total_seconds() / 86400
 
-    no_effort_bins = bin_starts[observed.counts.reindex(bin_starts) == 0]
-    for ts in no_effort_bins:
-        start = mdates.date2num(ts)
-        ax.axvspan(start, start + width_days, color="grey", alpha=0.08, zorder=1.5)
-
-    x_min, x_max = ax.get_xlim()
-    data_min = mdates.date2num(bin_starts[0])
-    data_max = mdates.date2num(bin_starts[-1]) + width_days
-
-    if x_min < data_min:
-        ax.axvspan(x_min, data_min, color="grey", alpha=0.08, zorder=1.5)
-    if x_max > data_max:
-        ax.axvspan(data_max, x_max, color="grey", alpha=0.08, zorder=1.5)
-    ax.set_xlim(x_min, x_max)
+
+    # Convert effort IntervalIndex → DatetimeIndex (bin starts)
+    effort_by_start = Series(
+        observed.counts.values,
+        index=[i.left for i in observed.counts.index],
+    ).tz_localize("UTC")
+
+    # Align effort to plotting bins
+    effort_aligned = effort_by_start.reindex(bin_starts)
+    max_effort = bar_width / observed.timebin_origin
+    effort_fraction = effort_aligned / max_effort
+
+    no_effort = effort_fraction == 0
+    partial_effort = (effort_fraction > 0) & (effort_fraction < 1)
+
+    # Draw partial effort first (lighter)
+    for ts in bin_starts[partial_effort]:
+        start = mdates.date2num(ts - bar_width)
+        ax.axvspan(
+            start,
+            start + width_days,
+            facecolor="0.65",
+            alpha=0.1,
+            linewidth=0,
+            zorder=0,
+        )
+
+    # Draw no effort on top (darker)
+    for ts in bin_starts[no_effort]:
+        start = mdates.date2num(ts - bar_width)
+        ax.axvspan(
+            start,
+            start + width_days,
+            facecolor="0.45",
+            alpha=0.15,
+            linewidth=0,
+            zorder=0,
+        )
+
+    handles = []
+
+    if partial_effort.any():
+        handles.append(
+            Patch(facecolor="0.65", alpha=0.1, label="partial data")
+        )
+
+    if no_effort.any():
+        handles.append(
+            Patch(facecolor="0.45", alpha=0.15, label="no data")
+        )
+
+    if handles:
+        ax.legend(handles=handles)
 
 
 def add_sunrise_sunset(ax: Axes, lat: float, lon: float, tz: tzinfo) -> None:
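The shading logic boils down to comparing achieved effort against the maximum effort a bar could hold. A minimal pandas-only sketch with illustrative numbers:

```python
from pandas import Series, Timedelta

bar_width = Timedelta("1D")         # one bar per day
timebin_origin = Timedelta("1min")  # effort counted in minutes
max_effort = bar_width / timebin_origin  # 1440.0 possible minutes per bar

counts = Series([1440, 620, 0])     # minutes actually recorded per bar
fraction = counts / max_effort      # 1.0, ~0.43, 0.0

no_effort = fraction == 0                         # shaded darker
partial_effort = (fraction > 0) & (fraction < 1)  # shaded lighter
print(fraction.tolist(), no_effort.tolist(), partial_effort.tolist())
```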
diff --git a/tests/conftest.py b/tests/conftest.py
index e03bf43..a6299e3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,6 +8,6 @@
 import yaml
 from osekit.utils.timestamp_utils import strftime_osmose_format
-from pandas import DataFrame, read_csv
+from pandas import DataFrame, Timedelta, read_csv
 
 SAMPLE = """dataset,filename,start_time,end_time,start_frequency,end_frequency,annotation,annotator,start_datetime,end_datetime,type,score
 sample_dataset,2025_01_25_06_20_00,0.0,10.0,0.0,72000.0,lbl2,ann2,2025-01-25T06:20:00.000+00:00,2025-01-25T06:20:10.000+00:00,WEAK,0.11
@@ -122,8 +122,6 @@
 
 """
 
-
-
 STATUS = """dataset,filename,ann1,ann2,ann3,ann4,ann5,ann6
 sample_dataset,2025_01_25_06_20_00,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED
 sample_dataset,2025_01_25_06_20_10,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED
@@ -134,6 +132,14 @@
 sample_dataset,2025_01_26_06_20_20,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED,FINISHED
 """
 
+# ---------------------------------------------------------------------------
+# Fake recording planning CSV used for tests
+# ---------------------------------------------------------------------------
+RECORDING_PLANNING_CSV = """start_recording,end_recording,start_deployment,end_deployment
+2024-01-01 00:00:00+0000,2024-04-09 02:00:00+0000,2024-01-02 00:00:00+0000,2024-04-30 02:00:00+0000
+2024-04-30 01:00:00+0000,2024-07-14 06:00:00+0000,2024-04-30 02:00:00+0000,2024-07-06 14:00:00+0000
+"""
+
 
 @pytest.fixture
 def sample_df() -> DataFrame:
@@ -228,3 +234,21 @@ def create_file(path: Path, size: int = 2048):
     create_file(nested / "file4.wav")
     (tmp_path / "ignore.txt").write_text("not audio")
     return tmp_path
+
+
+@pytest.fixture
+def recording_planning_csv(tmp_path) -> Path:
+    """Create a temporary CSV file simulating a recording planning."""
+    path = tmp_path / "recording_planning.csv"
+    path.write_text(RECORDING_PLANNING_CSV)
+    return path
+
+
+@pytest.fixture
+def recording_planning_config(recording_planning_csv):
+    """Minimal config object compatible with RecordingPeriod.from_path."""
+    class RecordingPlanningConfig:
+        timestamp_file: Path = recording_planning_csv
+        timebin_origin = Timedelta("1min")
+
+    return RecordingPlanningConfig()
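The fixture can stay a plain class because `RecordingPeriod.from_path` duck-types its config. An illustrative `Protocol` (not part of the codebase) spelling out that contract:

```python
from pathlib import Path
from typing import Protocol

from pandas import Timedelta


class EffortConfig(Protocol):
    """Illustrative only: the attributes from_path actually reads."""

    timestamp_file: Path
    timebin_origin: Timedelta
```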
diff --git a/tests/test_DataAplose.py b/tests/test_DataAplose.py
index 5ad1b04..9b9516c 100644
--- a/tests/test_DataAplose.py
+++ b/tests/test_DataAplose.py
@@ -19,6 +19,7 @@ def test_data_aplose_init(sample_df: DataFrame) -> None:
     assert data.begin == sample_df["start_datetime"].min()
     assert data.end == sample_df["end_datetime"].max()
 
+
 def test_filter_df_single_pair(sample_df: DataFrame) -> None:
     data = DataAplose(sample_df)
     filtered_data = data.filter_df(annotator="ann1", label="lbl1")
@@ -30,17 +31,19 @@ def test_filter_df_single_pair(sample_df: DataFrame) -> None:
     ].reset_index(drop=True)
     assert filtered_data.equals(expected)
 
+
 def test_change_tz(sample_df: DataFrame) -> None:
     data = DataAplose(sample_df)
-    new_tz = 'Etc/GMT-7'
+    new_tz = "Etc/GMT-7"
     data.change_tz(new_tz)
-    start_dt = data.df['start_datetime']
-    end_dt = data.df['end_datetime']
+    start_dt = data.df["start_datetime"]
+    end_dt = data.df["end_datetime"]
     assert all(ts.tz.zone == new_tz for ts in start_dt), f"The detection start timestamps have to be in {new_tz} timezone"
     assert all(ts.tz.zone == new_tz for ts in end_dt), f"The detection end timestamps have to be in {new_tz} timezone"
     assert data.begin.tz.zone == new_tz, f"The begin value of the DataAplose has to be in {new_tz} timezone"
     assert data.end.tz.zone == new_tz, f"The end value of the DataAplose has to be in {new_tz} timezone"
 
+
 def test_filter_df_multiple_pairs(sample_df: DataFrame) -> None:
     data = DataAplose(sample_df)
     filtered_data = data.filter_df(annotator=["ann1", "ann2"], label=["lbl1", "lbl2"])
diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py
index a1a3d73..e72e482 100644
--- a/tests/test_core_utils.py
+++ b/tests/test_core_utils.py
@@ -8,6 +8,8 @@
 
 from post_processing.dataclass.data_aplose import DataAplose
 from post_processing.utils.core_utils import (
+    add_recording_period,
+    add_season_period,
     add_weak_detection,
     get_coordinates,
     get_count,
@@ -15,13 +17,11 @@
     get_season,
     get_sun_times,
     get_time_range_and_bin_size,
+    json2df,
     localize_timestamps,
     round_begin_end_timestamps,
-    timedelta_to_str,
-    add_season_period,
-    add_recording_period,
     set_bar_height,
-    json2df,
+    timedelta_to_str,
 )
 
 
@@ -409,10 +409,11 @@ def test_add_season_no_data() -> None:
 
 
 # %% add_recording_period
 
+
 def test_add_recording_period_valid() -> None:
     fig, ax = plt.subplots()
     start = Timestamp("2025-01-01T00:00:00+00:00")
-    stop = Timestamp("2025-01-02T00:00:00+00:00")
+    stop = Timestamp("2025-01-02T00:00:00+00:00")
     ts = date_range(start=start, end=stop, freq="H", tz="UTC")
     values = list(range(len(ts)))
@@ -423,7 +424,7 @@ def test_add_recording_period_valid() -> None:
         [
             Timestamp("2025-01-01T00:00:00+00:00"),
             Timestamp("2025-01-02T00:00:00+00:00"),
-        ]
+        ],
         columns=["deployment_date", "recovery_date"],
     )
 
@@ -438,6 +439,7 @@ def test_add_recording_period_no_data() -> None:
 
 # %% set_bar_height
 
+
 def test_set_bar_height_valid() -> None:
     fig, ax = plt.subplots()
     start = Timestamp("2025-01-01T00:00:00+00:00")
@@ -457,6 +459,7 @@ def test_set_bar_height_no_data() -> None:
 
 # %% json2df
 
+
 def test_json2df_valid(tmp_path):
     fake_json = {
         "deployment_date": "2025-01-01T00:00:00+00:00",
@@ -474,9 +477,9 @@ def test_json2df_valid(tmp_path):
         [
             Timestamp("2025-01-01T00:00:00+00:00"),
             Timestamp("2025-01-02T00:00:00+00:00"),
-        ]
+        ],
         columns=["deployment_date", "recovery_date"],
     )
 
-    assert df.equals(expected)
\ No newline at end of file
+    assert df.equals(expected)
diff --git a/tests/test_filtering_utils.py b/tests/test_filtering_utils.py
index 95fd987..3ec3760 100644
--- a/tests/test_filtering_utils.py
+++ b/tests/test_filtering_utils.py
@@ -77,7 +77,7 @@ def test_find_delimiter_unsupported_delimiter(tmp_path: Path) -> None:
 
     with pytest.raises(
         ValueError,
-        match=r"unsupported delimiter '&'"
+        match=r"unsupported delimiter '&'",
     ):
         find_delimiter(file)
 
@@ -199,6 +199,7 @@ def test_filter_by_freq_valid(sample_df: DataFrame, f_min, f_max):
     if f_max is not None:
         assert (result["end_frequency"] <= f_max).all()
 
+
 @pytest.mark.parametrize(
     "f_min, f_max, expected_msg",
     [
@@ -216,8 +217,6 @@ def test_filter_by_freq_valid(sample_df: DataFrame, f_min, f_max):
         ),
     ],
 )
-
-
 def test_filter_by_freq_out_of_range(sample_df: DataFrame, f_min, f_max, expected_msg):
     with pytest.raises(ValueError, match=expected_msg):
         filter_by_freq(sample_df, f_min=f_min, f_max=f_max)
@@ -331,7 +330,7 @@ def test_get_timezone_several(sample_df: DataFrame) -> None:
     }
     sample_df = concat(
         [sample_df, DataFrame([new_row])],
-        ignore_index=False
+        ignore_index=False,
     )
     tz = get_timezone(sample_df)
     assert len(tz) == 2
@@ -340,6 +339,7 @@ def test_get_timezone_several(sample_df: DataFrame) -> None:
 
 # %% read DataFrame
 
+
 def test_read_dataframe_comma_delimiter(tmp_path: Path) -> None:
     csv_file = tmp_path / "test.csv"
     csv_file.write_text(
@@ -417,7 +417,7 @@ def test_no_timebin_several_tz(sample_df: DataFrame) -> None:
     }
     sample_df = concat(
         [sample_df, DataFrame([new_row])],
-        ignore_index=False
+        ignore_index=False,
     )
     timestamp_wav = to_datetime(sample_df["filename"], format="%Y_%m_%d_%H_%M_%S").dt.tz_localize(pytz.UTC)
 
@@ -429,7 +429,7 @@ def test_no_timebin_original_timebin(sample_df: DataFrame) -> None:
     tz = get_timezone(sample_df)
     timestamp_wav = to_datetime(
         sample_df["filename"],
-        format="%Y_%m_%d_%H_%M_%S"
+        format="%Y_%m_%d_%H_%M_%S",
     ).dt.tz_localize(tz)
     df_out = reshape_timebin(
         sample_df,
@@ -520,7 +520,7 @@ def test_simple_reshape_hourly(sample_df: DataFrame) -> None:
     tz = get_timezone(sample_df)
     timestamp_wav = to_datetime(
         sample_df["filename"],
-        format="%Y_%m_%d_%H_%M_%S"
+        format="%Y_%m_%d_%H_%M_%S",
     ).dt.tz_localize(tz)
     df_out = reshape_timebin(
         sample_df,
@@ -538,7 +538,7 @@ def test_reshape_daily_multiple_bins(sample_df: DataFrame) -> None:
     tz = get_timezone(sample_df)
     timestamp_wav = to_datetime(
         sample_df["filename"],
-        format="%Y_%m_%d_%H_%M_%S"
+        format="%Y_%m_%d_%H_%M_%S",
     ).dt.tz_localize(tz)
     df_out = reshape_timebin(sample_df, timestamp_audio=timestamp_wav, timebin_new=Timedelta(days=1))
     assert not df_out.empty
@@ -555,7 +555,7 @@ def test_with_manual_timestamps_vector(sample_df: DataFrame) -> None:
     df_out = reshape_timebin(
         sample_df,
         timestamp_audio=timestamp_wav,
-        timebin_new=Timedelta(hours=1)
+        timebin_new=Timedelta(hours=1),
     )
     assert not df_out.empty
 
@@ -589,6 +589,7 @@ def test_ensure_no_invalid_with_elements() -> None:
     assert "bar" in str(exc_info.value)
     assert "columns" in str(exc_info.value)
 
+
 def test_ensure_no_invalid_single_element() -> None:
     invalid_items = ["baz"]
     with pytest.raises(ValueError) as exc_info:
@@ -598,6 +599,7 @@ def test_ensure_no_invalid_single_element() -> None:
 
 # %% intersection / union
 
+
 def test_intersection(sample_df) -> None:
     df_result = intersection_or_union(sample_df[sample_df["annotator"].isin(["ann1", "ann2"])],
                                       user_sel="intersection")
@@ -628,7 +630,7 @@ def test_not_enough_annotators_raises() -> None:
         "annotation": ["cat"],
         "start_datetime": to_datetime(["2025-01-01 10:00"]),
         "end_datetime": to_datetime(["2025-01-01 10:01"]),
-        "annotator": ["A"]
+        "annotator": ["A"],
     })
 
     with pytest.raises(ValueError, match="Not enough annotators detected"):
-        intersection_or_union(df_single_annotator, user_sel="intersection")
\ No newline at end of file
+        intersection_or_union(df_single_annotator, user_sel="intersection")
diff --git a/tests/test_glider_utils.py b/tests/test_glider_utils.py
index 12d83df..d0247c5 100644
--- a/tests/test_glider_utils.py
+++ b/tests/test_glider_utils.py
@@ -56,7 +56,7 @@ def test_get_position_from_timestamp(nav_df: DataFrame) -> None:
 
 def test_plot_detections_with_nav_data(
     df_detections: DataFrame,
-    nav_df: DataFrame
+    nav_df: DataFrame,
 ) -> None:
     plot_detections_with_nav_data(
         df=df_detections,
diff --git a/tests/test_metric_utils.py b/tests/test_metric_utils.py
index 34ce769..35717e7 100644
--- a/tests/test_metric_utils.py
+++ b/tests/test_metric_utils.py
@@ -3,6 +3,7 @@
 
 from post_processing.utils.metrics_utils import detection_perf
 
+
 def test_detection_perf(sample_df: DataFrame) -> None:
     try:
         detection_perf(df=sample_df[sample_df["annotator"].isin(["ann1", "ann4"])], ref=("ann1", "lbl1"))
@@ -12,4 +13,4 @@ def test_detection_perf(sample_df: DataFrame) -> None:
 
 def test_detection_perf_one_annotator(sample_df: DataFrame) -> None:
     with pytest.raises(ValueError, match="Two annotators needed"):
-        detection_perf(df=sample_df[sample_df["annotator"] == "ann1"], ref=("ann1", "lbl1"))
\ No newline at end of file
+        detection_perf(df=sample_df[sample_df["annotator"] == "ann1"], ref=("ann1", "lbl1"))
diff --git a/tests/test_plot_utils.py b/tests/test_plot_utils.py
index d7392cf..ffabbbd 100644
--- a/tests/test_plot_utils.py
+++ b/tests/test_plot_utils.py
@@ -1,13 +1,13 @@
 import matplotlib.pyplot as plt
 import pytest
 from matplotlib.ticker import PercentFormatter
 from numpy import arange, testing
 
 from post_processing.utils.plot_utils import (
-    overview,
     _wrap_xtick_labels,
-    set_y_axis_to_percentage,
     get_legend,
+    overview,
+    set_y_axis_to_percentage,
 )
 
@@ -103,4 +103,4 @@ def test_lists_and_strings_combined():
     labels = ["Label1", "Label2"]
     result = get_legend(annotators, labels)
     expected = ["Alice\nLabel1", "Bob\nLabel2"]
-    assert result == expected
\ No newline at end of file
+    assert result == expected
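The new test below leans on the planning fixture added in conftest. For reviewers, the effective recording blocks it implies (intersection of recording and deployment), reproduced with plain pandas:

```python
from pandas import DataFrame, to_datetime

df = DataFrame({
    "start_recording": ["2024-01-01 00:00", "2024-04-30 01:00"],
    "end_recording": ["2024-04-09 02:00", "2024-07-14 06:00"],
    "start_deployment": ["2024-01-02 00:00", "2024-04-30 02:00"],
    "end_deployment": ["2024-04-30 02:00", "2024-07-06 14:00"],
}).apply(to_datetime)

effective_start = df[["start_recording", "start_deployment"]].max(axis=1)
effective_end = df[["end_recording", "end_deployment"]].min(axis=1)
print(list(zip(effective_start, effective_end)))
# block 1: 2024-01-02 00:00 -> 2024-04-09 02:00
# block 2: 2024-04-30 02:00 -> 2024-07-06 14:00
# => roughly three weeks of zero effort in between
```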
+ """ + recording_period = RecordingPeriod.from_path( + config=recording_planning_config, + bin_size=frequencies.to_offset("1W"), + ) + + counts = recording_period.counts + + # ------------------------------------------------------------------ + # Structural checks + # ------------------------------------------------------------------ + assert not counts.empty + assert counts.index.is_interval() + assert counts.min() >= 0 + + # One week = 7 * 24 hours (origin = 1min) + full_week_minutes = 7 * 24 * 60 + + # ------------------------------------------------------------------ + # Helper: find the bin covering a given timestamp + # ------------------------------------------------------------------ + def bin_covering(ts: pd.Timestamp) -> pd.Interval: + for interval in counts.index: + if interval.left <= ts < interval.right: + return interval + raise AssertionError(f"No bin covers timestamp {ts}") + + # ------------------------------------------------------------------ + # Week fully inside the long gap → zero effort + # ------------------------------------------------------------------ + gap_ts = pd.Timestamp("2024-04-21") + + gap_bin = bin_covering(gap_ts) + assert counts.loc[gap_bin] == 0 + + # ------------------------------------------------------------------ + # Week fully inside recording → full effort + # ------------------------------------------------------------------ + full_effort_ts = pd.Timestamp("2024-02-04") + + full_bin = bin_covering(full_effort_ts) + assert counts.loc[full_bin] == full_week_minutes + + # ------------------------------------------------------------------ + # Week overlapping recording stop → partial effort + # ------------------------------------------------------------------ + partial_ts = pd.Timestamp("2024-04-14") + + partial_bin = bin_covering(partial_ts) + assert counts.loc[partial_bin] == 1560