Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions inference/tests/test_classification_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,6 @@ def test_complex_mappings(self):
actual_tags = map_classification_to_tdamm_tags(classification_results, threshold=0.8)
assert sorted(actual_tags) == sorted(expected_tags)

@patch("django.conf.settings.TDAMM_CLASSIFICATION_THRESHOLD", 0.75)
def test_default_threshold_from_settings(self):
"""Test using the default threshold from settings"""
classification_results = {"Optical": 0.7, "Infrared": 0.8, "X-rays": 0.9}

# With settings threshold of 0.75, Infrared and X-rays should be included
expected_tags = ["MMA_M_EM_I", "MMA_M_EM_X"]
actual_tags = map_classification_to_tdamm_tags(classification_results) # No threshold provided

assert sorted(actual_tags) == sorted(expected_tags)


class TestUpdateUrlWithClassificationResults:
"""Tests for the update_url_with_classification_results function"""
Expand Down
138 changes: 138 additions & 0 deletions inference/tests/test_threshold_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# inference/tests/test_threshold_processor.py
# docker-compose -f local.yml run --rm django pytest inference/tests/test_threshold_processor.py

from unittest.mock import patch

import pytest

from inference.utils.threshold_processor import ClassificationThresholdProcessor


class TestClassificationThresholdProcessor:
"""Test suite for ClassificationThresholdProcessor class."""

@pytest.fixture
def test_thresholds(self):
"""Test thresholds for generic processor."""
return {"TAG_A": 0.7, "TAG_B": 0.5, "TAG_C": 0.8}

@pytest.fixture
def processor(self, test_thresholds):
"""Create a ClassificationThresholdProcessor with test thresholds."""
return ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)

def test_initialization(self, test_thresholds):
"""Test initialization with provided thresholds."""
processor = ClassificationThresholdProcessor(test_thresholds, default_threshold=0.6)
assert processor.thresholds == test_thresholds
assert processor.default_threshold == 0.6

def test_for_tdamm_factory(self):
"""Test the for_tdamm class factory method."""
with patch("inference.utils.threshold_processor.TDAMM_TAG_THRESHOLDS", {"TAG1": 0.7}), patch(
"inference.utils.threshold_processor.DEFAULT_TDAMM_THRESHOLD", 0.5
):
processor = ClassificationThresholdProcessor.for_tdamm()
assert processor.thresholds == {"TAG1": 0.7}
assert processor.default_threshold == 0.5

def test_for_division_factory(self):
"""Test the for_division class factory method."""
with patch("inference.utils.threshold_processor.DIVISION_TAG_THRESHOLDS", {1: 0.7}), patch(
"inference.utils.threshold_processor.DEFAULT_DIVISION_THRESHOLD", 0.5
):
processor = ClassificationThresholdProcessor.for_division()
assert processor.thresholds == {1: 0.7}
assert processor.default_threshold == 0.5

def test_get_threshold_exact_match(self, processor):
"""Test get_threshold with an exact tag match."""
assert processor.get_threshold("TAG_A") == 0.7
assert processor.get_threshold("TAG_B") == 0.5
assert processor.get_threshold("TAG_C") == 0.8

def test_get_threshold_no_match(self, processor):
"""Test get_threshold with a tag that doesn't exist."""
assert processor.get_threshold("UNKNOWN_TAG") == 0.6 # default threshold

def test_filter_classifications_all_pass(self, processor):
"""Test filter_classifications where all pass their thresholds."""
classifications = {
"TAG_A": 0.8, # 0.8 > 0.7 threshold
"TAG_B": 0.6, # 0.6 > 0.5 threshold
}
filtered = processor.filter_classifications(classifications)
assert len(filtered) == 2
assert "TAG_A" in filtered
assert "TAG_B" in filtered

def test_filter_classifications_some_pass(self, processor):
"""Test filter_classifications where some pass their thresholds."""
classifications = {
"TAG_A": 0.6, # 0.6 < 0.7 threshold
"TAG_B": 0.6, # 0.6 > 0.5 threshold
"TAG_C": 0.9, # 0.9 > 0.8 threshold
}
filtered = processor.filter_classifications(classifications)
assert len(filtered) == 2
assert "TAG_A" not in filtered
assert "TAG_B" in filtered
assert "TAG_C" in filtered

def test_filter_classifications_none_pass(self, processor):
"""Test filter_classifications where none pass their thresholds."""
classifications = {
"TAG_A": 0.6, # 0.6 < 0.7 threshold
"TAG_C": 0.7, # 0.7 < 0.8 threshold
}
filtered = processor.filter_classifications(classifications)
assert len(filtered) == 0

def test_filter_classifications_default_threshold(self, processor):
"""Test filter_classifications using default threshold for unknown tags."""
classifications = {
"UNKNOWN_TAG": 0.7, # 0.7 > 0.6 default threshold
}
filtered = processor.filter_classifications(classifications)
assert len(filtered) == 1
assert "UNKNOWN_TAG" in filtered

def test_filter_classifications_string_confidence(self, processor):
"""Test filter_classifications with string confidence values."""
classifications = {
"TAG_A": "0.8", # Should convert to float and pass
"TAG_B": "0.4", # Should convert to float and fail
}
filtered = processor.filter_classifications(classifications)
assert len(filtered) == 1
assert "TAG_A" in filtered
assert "TAG_B" not in filtered

def test_filter_classifications_invalid_confidence(self, processor):
"""Test filter_classifications with invalid confidence values."""
classifications = {
"TAG_A": 0.8,
"TAG_B": "not a number", # Invalid
"TAG_C": 0.9,
}
filtered = processor.filter_classifications(classifications)
assert len(filtered) == 2
assert "TAG_A" in filtered
assert "TAG_B" not in filtered
assert "TAG_C" in filtered

def test_filter_classifications_exact_threshold(self, processor):
"""Test filter_classifications with confidence exactly at threshold."""
classifications = {
"TAG_A": 0.7, # Exactly at threshold (0.7)
"TAG_B": 0.5, # Exactly at threshold (0.5)
}
filtered = processor.filter_classifications(classifications)
assert len(filtered) == 2
assert "TAG_A" in filtered
assert "TAG_B" in filtered

def test_filter_classifications_empty_dict(self, processor):
"""Test filter_classifications with an empty dictionary."""
filtered = processor.filter_classifications({})
assert filtered == {}
28 changes: 19 additions & 9 deletions inference/utils/classification_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.conf import settings

from inference.utils.threshold_processor import ClassificationThresholdProcessor
from sde_collections.models.collection_choice_fields import TDAMMTags


Expand All @@ -18,6 +19,9 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
if threshold is None:
threshold = float(getattr(settings, "TDAMM_CLASSIFICATION_THRESHOLD"))

# Initialize the threshold processor
threshold_processor = ClassificationThresholdProcessor.for_tdamm()

selected_tags = []

# Build a mapping from simplified tag names to actual TDAMMTags values
Expand All @@ -35,30 +39,36 @@ def map_classification_to_tdamm_tags(classification_results, threshold=None):
tag_mapping["supernovae"] = tag_value

# Process classification results
tdamm_confidences = {}
for classification_key, confidence in classification_results.items():
if isinstance(confidence, str):
try:
confidence = float(confidence)
except (ValueError, TypeError):
continue

if confidence < threshold:
continue

# Normalize the classification key
normalized_key = classification_key.lower()
tag_value = None

# Try to find a match in our mapping
if normalized_key in tag_mapping:
selected_tags.append(tag_mapping[normalized_key])
tag_value = tag_mapping[normalized_key]
else:
# Try partial matching for more complex cases
for tag_key, tag_value in tag_mapping.items():
if tag_key in normalized_key or normalized_key in tag_key:
selected_tags.append(tag_value)
# Try partial matching
for key, value in tag_mapping.items():
if key in normalized_key or normalized_key in key:
tag_value = value
break

return selected_tags
# Skip if no matching tag found
if not tag_value:
continue

tdamm_confidences[tag_value] = confidence

selected_tags = threshold_processor.filter_classifications(tdamm_confidences)
return list(selected_tags.keys())


def update_url_with_classification_results(url_object, classification_results, threshold=None):
Expand Down
58 changes: 58 additions & 0 deletions inference/utils/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Configuration settings for classification thresholds."""

# Configuration settings for TDAMM tag classification thresholds
# Format: "tag_name": threshold_value (0.0 to 1.0)
TDAMM_TAG_THRESHOLDS = {
"NOT_TDAMM": 0.7, # non-TDAMM
"MMA_O_BH_AGN": 0.5, # Active Galactic Nuclei
"MMA_O_BI_BBH": 0.7, # Binary Black Holes
"MMA_O_BI_BNS": 0.6, # Binary Neutron Stars
"MMA_O_BI_B": 0.7, # Binary Pulsars
"MMA_M_G_B": 0.5, # Burst
"MMA_O_BI_C": 0.7, # Cataclysmic Variables
"MMA_M_G_CBI": 0.5, # Compact Binary Inspiral
"MMA_M_G_CON": 0.5, # Continuous
"MMA_M_C": 0.8, # Cosmic Rays
"MMA_O_E": 0.7, # Exoplanets
"MMA_S_FBOT": 0.7, # Fast Blue Optical Transients
"MMA_S_F": 0.7, # Fast Radio Bursts
"MMA_M_EM_G": 0.5, # Gamma rays
"MMA_S_G": 0.8, # Gamma-ray Bursts
"MMA_M_EM_I": 0.8, # Infrared
"MMA_O_BH_IM": 0.5, # Intermediate Mass
"MMA_S_K": 0.5, # Kilonovae
"MMA_O_N_M": 0.7, # Magnetars
"MMA_M_N": 0.5, # Neutrinos
"MMA_O_BI_N": 0.5, # Neutron Star-Black Hole
"MMA_S_N": 0.8, # Novae
"MMA_M_EM_O": 0.7, # Optical
"MMA_S_P": 0.5, # Pevatrons
"MMA_O_N_PWN": 0.8, # Pulsar Wind Nebulae
"MMA_O_N_P": 0.5, # Pulsars
"MMA_M_EM_R": 0.8, # Radio
"MMA_O_BH_STM": 0.5, # Stellar Mass
"MMA_S_ST": 0.7, # Stellar flares
"MMA_M_G_S": 0.8, # Stochastic
"MMA_S_SU": 0.8, # SuperNovae
"MMA_O_BH_SUM": 0.5, # Supermassive
"MMA_O_S": 0.6, # Supernova Remnants
"MMA_M_EM_U": 0.7, # Ultraviolet
"MMA_O_BI_W": 0.7, # White Dwarf Binaries
"MMA_M_EM_X": 0.8, # X-rays
}

# Default threshold to use if a specific tag isn't defined above
DEFAULT_TDAMM_THRESHOLD = 0.5

# Threshold values for different Division classifications
DIVISION_TAG_THRESHOLDS = {
"Astrophysics": 0.5,
"Biological and Physical Sciences": 0.5,
"Earth Science": 0.5,
"Heliophysics": 0.5,
"Planetary Science": 0.5,
"General": 0.5,
}

# Default threshold for Division classification
DEFAULT_DIVISION_THRESHOLD = 0.5
79 changes: 79 additions & 0 deletions inference/utils/threshold_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Module for processing classifications with tag-specific thresholds."""

from inference.utils.config import (
DEFAULT_DIVISION_THRESHOLD,
DEFAULT_TDAMM_THRESHOLD,
DIVISION_TAG_THRESHOLDS,
TDAMM_TAG_THRESHOLDS,
)


class ClassificationThresholdProcessor:
"""
Generic processor for classifications using tag-specific thresholds.
Can be used with any classification system where different classes
need different confidence thresholds.
"""

def __init__(self, thresholds: dict[str, float], default_threshold: float = 0.5):
"""
Initialize the processor with thresholds.

Args:
thresholds: Dictionary of classification tags and their threshold values.
default_threshold: Default threshold to use if tag isn't in thresholds.
"""
self.thresholds = thresholds
self.default_threshold = default_threshold

@classmethod
def for_tdamm(cls):
"""Create a processor for TDAMM classification."""
return cls(TDAMM_TAG_THRESHOLDS, DEFAULT_TDAMM_THRESHOLD)

@classmethod
def for_division(cls):
"""Create a processor for Division classification."""
return cls(DIVISION_TAG_THRESHOLDS, DEFAULT_DIVISION_THRESHOLD)

def get_threshold(self, tag: str) -> float:
"""
Get the threshold for a tag.

Args:
tag: The tag to get threshold for

Returns:
The threshold value as a float
"""
return self.thresholds.get(tag, self.default_threshold)

def filter_classifications(self, classifications: dict[str, float | str]) -> dict[str, float]:
"""
Filter classifications based on their thresholds.

Args:
classifications: Dictionary with classification keys and confidence scores

Returns:
Dictionary with classifications that passed their thresholds
"""
result = {}
for key, confidence in classifications.items():
# Convert confidence to float if it's a string
if isinstance(confidence, str):
try:
confidence_value = float(confidence)
except (ValueError, TypeError):
continue
else:
confidence_value = confidence

# Get the threshold for this classification
threshold = self.get_threshold(key)

# Keep only classifications that meet their threshold
if confidence_value >= threshold:
result[key] = confidence_value

return result
1 change: 0 additions & 1 deletion production.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ services:
image: sde_indexing_helper_production_celerybeat
container_name: sde_indexing_helper_production_celerybeat
depends_on:
- awscli
- postgres
ports: []
command: /start-celerybeat
Expand Down