From c1331e467b7ef1fc2b0b67363ae34004476df4bf Mon Sep 17 00:00:00 2001 From: Dhanur Sharma Date: Mon, 24 Mar 2025 13:01:15 -0500 Subject: [PATCH 1/5] Removed awscli as a dependency for celerybeat --- production.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/production.yml b/production.yml index ba306675..995ef423 100644 --- a/production.yml +++ b/production.yml @@ -58,7 +58,6 @@ services: image: sde_indexing_helper_production_celerybeat container_name: sde_indexing_helper_production_celerybeat depends_on: - - awscli - postgres ports: [] command: /start-celerybeat From d39f768ee54f5f4b5a415c26407280df74e81cbc Mon Sep 17 00:00:00 2001 From: Dhanur Sharma Date: Mon, 24 Mar 2025 13:03:06 -0500 Subject: [PATCH 2/5] Added threshold processor and config file --- inference/utils/config.py | 58 +++++++++++++++++++ inference/utils/threshold_processor.py | 79 ++++++++++++++++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 inference/utils/config.py create mode 100644 inference/utils/threshold_processor.py diff --git a/inference/utils/config.py b/inference/utils/config.py new file mode 100644 index 00000000..52e03868 --- /dev/null +++ b/inference/utils/config.py @@ -0,0 +1,58 @@ +"""Configuration settings for classification thresholds.""" + +# Configuration settings for TDAMM tag classification thresholds +# Format: "tag_name": threshold_value (0.0 to 1.0) +TDAMM_TAG_THRESHOLDS = { + "NOT_TDAMM": 0.7, # non-TDAMM + "MMA_O_BH_AGN": 0.5, # Active Galactic Nuclei + "MMA_O_BI_BBH": 0.7, # Binary Black Holes + "MMA_O_BI_BNS": 0.6, # Binary Neutron Stars + "MMA_O_BI_B": 0.7, # Binary Pulsars + "MMA_M_G_B": 0.5, # Burst + "MMA_O_BI_C": 0.7, # Cataclysmic Variables + "MMA_M_G_CBI": 0.5, # Compact Binary Inspiral + "MMA_M_G_CON": 0.5, # Continuous + "MMA_M_C": 0.8, # Cosmic Rays + "MMA_O_E": 0.7, # Exoplanets + "MMA_S_FBOT": 0.7, # Fast Blue Optical Transients + "MMA_S_F": 0.7, # Fast Radio Bursts + "MMA_M_EM_G": 0.5, # Gamma rays + "MMA_S_G": 0.8, # Gamma-ray Bursts + "MMA_M_EM_I": 0.8, # Infrared + "MMA_O_BH_IM": 0.5, # Intermediate Mass + "MMA_S_K": 0.5, # Kilonovae + "MMA_O_N_M": 0.7, # Magnetars + "MMA_M_N": 0.5, # Neutrinos + "MMA_O_BI_N": 0.5, # Neutron Star-Black Hole + "MMA_S_N": 0.8, # Novae + "MMA_M_EM_O": 0.7, # Optical + "MMA_S_P": 0.5, # Pevatrons + "MMA_O_N_PWN": 0.8, # Pulsar Wind Nebulae + "MMA_O_N_P": 0.5, # Pulsars + "MMA_M_EM_R": 0.8, # Radio + "MMA_O_BH_STM": 0.5, # Stellar Mass + "MMA_S_ST": 0.7, # Stellar flares + "MMA_M_G_S": 0.8, # Stochastic + "MMA_S_SU": 0.8, # SuperNovae + "MMA_O_BH_SUM": 0.5, # Supermassive + "MMA_O_S": 0.6, # Supernova Remnants + "MMA_M_EM_U": 0.7, # Ultraviolet + "MMA_O_BI_W": 0.7, # White Dwarf Binaries + "MMA_M_EM_X": 0.8, # X-rays +} + +# Default threshold to use if a specific tag isn't defined above +DEFAULT_TDAMM_THRESHOLD = 0.5 + +# Threshold values for different Division classifications +DIVISION_TAG_THRESHOLDS = { + "Astrophysics": 0.5, + "Biological and Physical Sciences": 0.5, + "Earth Science": 0.5, + "Heliophysics": 0.5, + "Planetary Science": 0.5, + "General": 0.5, +} + +# Default threshold for Division classification +DEFAULT_DIVISION_THRESHOLD = 0.5 diff --git a/inference/utils/threshold_processor.py b/inference/utils/threshold_processor.py new file mode 100644 index 00000000..89f7e896 --- /dev/null +++ b/inference/utils/threshold_processor.py @@ -0,0 +1,79 @@ +"""Module for processing classifications with tag-specific thresholds.""" + +from inference.utils.config import ( + DEFAULT_DIVISION_THRESHOLD, + DEFAULT_TDAMM_THRESHOLD, + DIVISION_TAG_THRESHOLDS, + TDAMM_TAG_THRESHOLDS, +) + + +class ClassificationThresholdProcessor: + """ + Generic processor for classifications using tag-specific thresholds. + Can be used with any classification system where different classes + need different confidence thresholds. + """ + + def __init__(self, thresholds: dict[str, float], default_threshold: float = 0.5): + """ + Initialize the processor with thresholds. + + Args: + thresholds: Dictionary of classification tags and their threshold values. + default_threshold: Default threshold to use if tag isn't in thresholds. + """ + self.thresholds = thresholds + self.default_threshold = default_threshold + + @classmethod + def for_tdamm(cls): + """Create a processor for TDAMM classification.""" + return cls(TDAMM_TAG_THRESHOLDS, DEFAULT_TDAMM_THRESHOLD) + + @classmethod + def for_division(cls): + """Create a processor for Division classification.""" + return cls(DIVISION_TAG_THRESHOLDS, DEFAULT_DIVISION_THRESHOLD) + + def get_threshold(self, tag: str) -> float: + """ + Get the threshold for a tag. + + Args: + tag: The tag to get threshold for + + Returns: + The threshold value as a float + """ + return self.thresholds.get(tag, self.default_threshold) + + def filter_classifications(self, classifications: dict[str, float | str]) -> dict[str, float]: + """ + Filter classifications based on their thresholds. + + Args: + classifications: Dictionary with classification keys and confidence scores + + Returns: + Dictionary with classifications that passed their thresholds + """ + result = {} + for key, confidence in classifications.items(): + # Convert confidence to float if it's a string + if isinstance(confidence, str): + try: + confidence_value = float(confidence) + except (ValueError, TypeError): + continue + else: + confidence_value = confidence + + # Get the threshold for this classification + threshold = self.get_threshold(key) + + # Keep only classifications that meet their threshold + if confidence_value >= threshold: + result[key] = confidence_value + + return result From 843dd00eb788ee27bf271ba5452007df37e8e56f Mon Sep 17 00:00:00 2001 From: Dhanur Sharma Date: Mon, 24 Mar 2025 13:03:49 -0500 Subject: [PATCH 3/5] Updated classification_utils to use the threshold_processor --- inference/utils/classification_utils.py | 28 +++++++++++++++++-------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/inference/utils/classification_utils.py b/inference/utils/classification_utils.py index 37fab7d8..0e757e36 100644 --- a/inference/utils/classification_utils.py +++ b/inference/utils/classification_utils.py @@ -1,5 +1,6 @@ from django.conf import settings +from inference.utils.threshold_processor import ClassificationThresholdProcessor from sde_collections.models.collection_choice_fields import TDAMMTags @@ -18,6 +19,9 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None): if threshold is None: threshold = float(getattr(settings, "TDAMM_CLASSIFICATION_THRESHOLD")) + # Initialize the threshold processor + threshold_processor = ClassificationThresholdProcessor.for_tdamm() + selected_tags = [] # Build a mapping from simplified tag names to actual TDAMMTags values @@ -35,6 +39,7 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None): tag_mapping["supernovae"] = tag_value # Process classification results + tdamm_confidences = {} for classification_key, confidence in classification_results.items(): if isinstance(confidence, str): try: @@ -42,23 +47,28 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None): except (ValueError, TypeError): continue - if confidence < threshold: - continue - # Normalize the classification key normalized_key = classification_key.lower() + tag_value = None # Try to find a match in our mapping if normalized_key in tag_mapping: - selected_tags.append(tag_mapping[normalized_key]) + tag_value = tag_mapping[normalized_key] else: - # Try partial matching for more complex cases - for tag_key, tag_value in tag_mapping.items(): - if tag_key in normalized_key or normalized_key in tag_key: - selected_tags.append(tag_value) + # Try partial matching + for key, value in tag_mapping.items(): + if key in normalized_key or normalized_key in key: + tag_value = value break - return selected_tags + # Skip if no matching tag found + if not tag_value: + continue + + tdamm_confidences[tag_value] = confidence + + selected_tags = threshold_processor.filter_classifications(tdamm_confidences) + return list(selected_tags.keys()) def update_url_with_classification_results(url_object, classification_results, threshold=None): From 9c12335611d2b9f764517d7ef03a832b28eb055b Mon Sep 17 00:00:00 2001 From: Dhanur Sharma Date: Mon, 24 Mar 2025 13:07:52 -0500 Subject: [PATCH 4/5] Added tests for threshold processor --- inference/tests/test_threshold_processor.py | 138 ++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 inference/tests/test_threshold_processor.py diff --git a/inference/tests/test_threshold_processor.py b/inference/tests/test_threshold_processor.py new file mode 100644 index 00000000..4bb83ba7 --- /dev/null +++ b/inference/tests/test_threshold_processor.py @@ -0,0 +1,138 @@ +# inference/tests/test_threshold_processor.py +# docker-compose -f local.yml run --rm django pytest inference/tests/test_threshold_processor.py + +from unittest.mock import patch + +import pytest + +from inference.utils.threshold_processor import ClassificationThresholdProcessor + + +class TestClassificationThresholdProcessor: + """Test suite for ClassificationThresholdProcessor class.""" + + @pytest.fixture + def test_thresholds(self): + """Test thresholds for generic processor.""" + return {"TAG_A": 0.7, "TAG_B": 0.5, "TAG_C": 0.8} + + @pytest.fixture + def processor(self, test_thresholds): + """Create a ClassificationThresholdProcessor with test thresholds.""" + return ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6) + + def test_initialization(self, test_thresholds): + """Test initialization with provided thresholds.""" + processor = ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6) + assert processor.thresholds == test_thresholds + assert processor.default_threshold == 0.6 + + def test_for_tdamm_factory(self): + """Test the for_tdamm class factory method.""" + with patch("inference.utils.threshold_processor.TDAMM_TAG_THRESHOLDS", {"TAG1": 0.7}), patch( + "inference.utils.threshold_processor.DEFAULT_TDAMM_THRESHOLD", 0.5 + ): + processor = ClassificationThresholdProcessor.for_tdamm() + assert processor.thresholds == {"TAG1": 0.7} + assert processor.default_threshold == 0.5 + + def test_for_division_factory(self): + """Test the for_division class factory method.""" + with patch("inference.utils.threshold_processor.DIVISION_TAG_THRESHOLDS", {1: 0.7}), patch( + "inference.utils.threshold_processor.DEFAULT_DIVISION_THRESHOLD", 0.5 + ): + processor = ClassificationThresholdProcessor.for_division() + assert processor.thresholds == {1: 0.7} + assert processor.default_threshold == 0.5 + + def test_get_threshold_exact_match(self, processor): + """Test get_threshold with an exact tag match.""" + assert processor.get_threshold("TAG_A") == 0.7 + assert processor.get_threshold("TAG_B") == 0.5 + assert processor.get_threshold("TAG_C") == 0.8 + + def test_get_threshold_no_match(self, processor): + """Test get_threshold with a tag that doesn't exist.""" + assert processor.get_threshold("UNKNOWN_TAG") == 0.6 # default threshold + + def test_filter_classifications_all_pass(self, processor): + """Test filter_classifications where all pass their thresholds.""" + classifications = { + "TAG_A": 0.8, # 0.8 > 0.7 threshold + "TAG_B": 0.6, # 0.6 > 0.5 threshold + } + filtered = processor.filter_classifications(classifications) + assert len(filtered) == 2 + assert "TAG_A" in filtered + assert "TAG_B" in filtered + + def test_filter_classifications_some_pass(self, processor): + """Test filter_classifications where some pass their thresholds.""" + classifications = { + "TAG_A": 0.6, # 0.6 < 0.7 threshold + "TAG_B": 0.6, # 0.6 > 0.5 threshold + "TAG_C": 0.9, # 0.9 > 0.8 threshold + } + filtered = processor.filter_classifications(classifications) + assert len(filtered) == 2 + assert "TAG_A" not in filtered + assert "TAG_B" in filtered + assert "TAG_C" in filtered + + def test_filter_classifications_none_pass(self, processor): + """Test filter_classifications where none pass their thresholds.""" + classifications = { + "TAG_A": 0.6, # 0.6 < 0.7 threshold + "TAG_C": 0.7, # 0.7 < 0.8 threshold + } + filtered = processor.filter_classifications(classifications) + assert len(filtered) == 0 + + def test_filter_classifications_default_threshold(self, processor): + """Test filter_classifications using default threshold for unknown tags.""" + classifications = { + "UNKNOWN_TAG": 0.7, # 0.7 > 0.6 default threshold + } + filtered = processor.filter_classifications(classifications) + assert len(filtered) == 1 + assert "UNKNOWN_TAG" in filtered + + def test_filter_classifications_string_confidence(self, processor): + """Test filter_classifications with string confidence values.""" + classifications = { + "TAG_A": "0.8", # Should convert to float and pass + "TAG_B": "0.4", # Should convert to float and fail + } + filtered = processor.filter_classifications(classifications) + assert len(filtered) == 1 + assert "TAG_A" in filtered + assert "TAG_B" not in filtered + + def test_filter_classifications_invalid_confidence(self, processor): + """Test filter_classifications with invalid confidence values.""" + classifications = { + "TAG_A": 0.8, + "TAG_B": "not a number", # Invalid + "TAG_C": 0.9, + } + filtered = processor.filter_classifications(classifications) + assert len(filtered) == 2 + assert "TAG_A" in filtered + assert "TAG_B" not in filtered + assert "TAG_C" in filtered + + def test_filter_classifications_exact_threshold(self, processor): + """Test filter_classifications with confidence exactly at threshold.""" + classifications = { + "TAG_A": 0.7, # Exactly at threshold (0.7) + "TAG_B": 0.5, # Exactly at threshold (0.5) + } + filtered = processor.filter_classifications(classifications) + assert len(filtered) == 2 + assert "TAG_A" in filtered + assert "TAG_B" in filtered + + def test_filter_classifications_empty_dict(self, processor): + """Test filter_classifications with an empty dictionary.""" + filtered = processor.filter_classifications({}) + assert filtered == {} From 54fba983f37b441161cbcbc4e73beaa539d5dca3 Mon Sep 17 00:00:00 2001 From: Dhanur Sharma Date: Mon, 24 Mar 2025 13:20:42 -0500 Subject: [PATCH 5/5] Removed outdated test --- inference/tests/test_classification_utils.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/inference/tests/test_classification_utils.py b/inference/tests/test_classification_utils.py index e6cab651..2f992f3d 100644 --- a/inference/tests/test_classification_utils.py +++ b/inference/tests/test_classification_utils.py @@ -127,17 +127,6 @@ def test_complex_mappings(self): actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8) assert sorted(actual_tags) == sorted(expected_tags) - @patch("django.conf.settings.TDAMM_CLASSIFICATION_THRESHOLD", 0.75) - def test_default_threshold_from_settings(self): - """Test using the default threshold from settings""" - classification_results = {"Optical": 0.7, "Infrared": 0.8, "X-rays": 0.9} - - # With settings threshold of 0.75, Infrared and X-rays should be included - expected_tags = ["MMA_M_EM_I", "MMA_M_EM_X"] - actual_tags = map_classification_to_tdamm_tags(classification_results) # No threshold provided - - assert sorted(actual_tags) == sorted(expected_tags) - class TestUpdateUrlWithClassificationResults: """Tests for the update_url_with_classification_results function"""