diff --git a/sde_collections/models/collection.py b/sde_collections/models/collection.py index baa3fe77..52a520d9 100644 --- a/sde_collections/models/collection.py +++ b/sde_collections/models/collection.py @@ -666,9 +666,27 @@ def generate_inference_job(self, classification_type): InferenceJob = apps.get_model("inference", "InferenceJob") ModelVersion = apps.get_model("inference", "ModelVersion") + + try: + model_version = ModelVersion.get_active_version(classification_type) + except ModelVersion.DoesNotExist: + if classification_type == 1: # TDAMM + model_name = "TDAMM" + elif classification_type == 2: # DIVISION + model_name = "DC" + else: + raise ValueError(f"Unsupported classification type: {classification_type}") + + model_version = ModelVersion.objects.create( + api_identifier=model_name, + description=f"{model_name.upper()} classification model", + classification_type=classification_type, + is_active=True, + ) + return InferenceJob.objects.create( collection=self, - model_version=ModelVersion.get_active_version(classification_type), + model_version=model_version, ) def queue_necessary_classifications(self):