diff --git a/pyproject.toml b/pyproject.toml index 9550fb6..d202dd8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,9 +39,6 @@ dev = [ "coverage>=7.11.0", ] -[tool.ruff] -exclude = ["scripts/**/*.py"] - [tool.ruff.lint.flake8-copyright] author = "OSmOSE" diff --git a/src/post_processing/dataclass/data_aplose.py b/src/post_processing/dataclass/data_aplose.py index 59cb01b..3165de8 100644 --- a/src/post_processing/dataclass/data_aplose.py +++ b/src/post_processing/dataclass/data_aplose.py @@ -287,9 +287,9 @@ def set_ax( return ax - def overview(self) -> None: + def overview(self, annotator: list[str] | None = None) -> None: """Overview of an APLOSE formatted DataFrame.""" - overview(self.df) + overview(self.df, annotator) def detection_perf( self, diff --git a/src/post_processing/utils/filtering_utils.py b/src/post_processing/utils/filtering_utils.py index 5083c7d..d4c3301 100644 --- a/src/post_processing/utils/filtering_utils.py +++ b/src/post_processing/utils/filtering_utils.py @@ -36,6 +36,23 @@ def find_delimiter(file: Path) -> str: return delimiter +def filter_strong_detection( + df: DataFrame, +) -> DataFrame: + """Filter strong detections of a DataFrame.""" + if "type" in df.columns: + df = df[df["type"] == "WEAK"] + elif "is_box" in df.columns: + df = df[df["is_box"] == 0] + else: + msg = "Could not determine annotation type." + raise ValueError(msg) + if df.empty: + msg = "No weak detection found." + raise ValueError(msg) + return df + + def filter_by_time( df: DataFrame, begin: Timestamp | None, @@ -333,6 +350,8 @@ def load_detections(filters: DetectionFilter) -> DataFrame: """ df = read_dataframe(filters.detection_file) + if filters.box: + df = filter_strong_detection(df) df = filter_by_time(df, filters.begin, filters.end) df = filter_by_annotator(df, annotator=filters.annotator) df = filter_by_label(df, label=filters.annotation) diff --git a/src/post_processing/utils/metrics_utils.py b/src/post_processing/utils/metrics_utils.py index a21c618..b610d7f 100644 --- a/src/post_processing/utils/metrics_utils.py +++ b/src/post_processing/utils/metrics_utils.py @@ -18,7 +18,7 @@ def detection_perf( timestamps: list[Timestamp] | None = None, *, ref: tuple[str, str], -) -> None: +) -> tuple[float, float, float]: """Compute performances metrics for detection. Performances are computed with a reference annotator in @@ -128,6 +128,8 @@ def detection_perf( logging.info(f"Recall: {recall:.2f}") logging.info(f"F-score: {f_score:.2f}") + return precision, recall, f_score + def _map_datetimes_to_vector(df: DataFrame, timestamps: list[int]) -> ndarray: """Map datetime ranges to a binary vector indicating overlap with timestamp bins. diff --git a/src/post_processing/utils/plot_utils.py b/src/post_processing/utils/plot_utils.py index 9fdcefb..3c34fbd 100644 --- a/src/post_processing/utils/plot_utils.py +++ b/src/post_processing/utils/plot_utils.py @@ -27,7 +27,11 @@ round_begin_end_timestamps, timedelta_to_str, ) -from post_processing.utils.filtering_utils import get_max_time, get_timezone +from post_processing.utils.filtering_utils import ( + get_max_time, + get_timezone, + filter_by_annotator, +) from post_processing.utils.metrics_utils import normalize_counts_by_effort if TYPE_CHECKING: @@ -368,15 +372,20 @@ def heatmap(df: DataFrame, ax.set_xlabel(f"Time ({bin_size_str} bin)") -def overview(df: DataFrame) -> None: +def overview(df: DataFrame, annotator: list[str] | None = None) -> None: """Overview of an APLOSE formatted DataFrame. Parameters ---------- df: DataFrame The Dataframe to analyse. + annotator: list[str] + List of annotators. """ + if annotator is not None: + df = filter_by_annotator(df, annotator) + summary_label = ( df.groupby("annotation")["annotator"] # noqa: PD010 .apply(Counter) diff --git a/tests/conftest.py b/tests/conftest.py index 800b6ee..32991e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -172,7 +172,7 @@ def sample_yaml(tmp_path: Path, "end": None, "annotator": "ann1", "annotation": "lbl1", - "box": False, + "box": True, "timestamp_file": f"{sample_csv_timestamp}", "user_sel": "all", "f_min": None, diff --git a/tests/test_DataAplose.py b/tests/test_DataAplose.py index c0b8af5..5ad1b04 100644 --- a/tests/test_DataAplose.py +++ b/tests/test_DataAplose.py @@ -179,7 +179,10 @@ def test_set_ax(sample_df: DataFrame) -> None: assert isinstance(locator, mdates.HourLocator) -def test_from_yaml(sample_yaml: Path, sample_df: DataFrame) -> None: +def test_from_yaml( + sample_yaml: Path, + sample_df: DataFrame, +) -> None: df_from_yaml = DataAplose.from_yaml(file=sample_yaml).df df_expected = DataAplose(sample_df).filter_df(annotator="ann1", label="lbl1").reset_index(drop=True) assert df_from_yaml.equals(df_expected) diff --git a/tests/test_DetectionFilters.py b/tests/test_DetectionFilters.py index 286a7a8..6b7dd2f 100644 --- a/tests/test_DetectionFilters.py +++ b/tests/test_DetectionFilters.py @@ -16,7 +16,7 @@ def test_from_yaml(sample_yaml: Path, "end": None, "annotator": "ann1", "annotation": "lbl1", - "box": False, + "box": True, "timestamp_file": f"{sample_csv_timestamp}", "user_sel": "all", "f_min": None, diff --git a/tests/test_filtering_utils.py b/tests/test_filtering_utils.py index 622ad2a..93021a0 100644 --- a/tests/test_filtering_utils.py +++ b/tests/test_filtering_utils.py @@ -7,6 +7,7 @@ from post_processing.utils.filtering_utils import ( filter_by_annotator, + filter_strong_detection, filter_by_freq, filter_by_label, filter_by_score, @@ -144,6 +145,17 @@ def test_filter_by_score_missing_column(sample_df: DataFrame) -> None: filter_by_score(df, 0.5) +# filter_weak_strong_detection +def test_filter_weak_only(sample_df: DataFrame) -> None: + df = filter_strong_detection(sample_df) + assert set(df["is_box"]) == {0} + + +def test_filter_weak_empty(sample_df: DataFrame) -> None: + with pytest.raises(ValueError, match="No weak detection found"): + filter_strong_detection(sample_df[sample_df["is_box"] == 1]) + + def test_get_annotators(sample_df: DataFrame) -> None: annotators = get_annotators(sample_df) expected = sorted(set(sample_df["annotator"]))