From d2d0bdc6b97b3485d436ab024206b5785a431cb8 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Wed, 4 Feb 2026 18:35:54 +0100 Subject: [PATCH 1/3] refactor dojo async task base task The custom decorators that we have on Celery tasks interfere with some (advanced) celery functionality like signatures. This PR refactors this to have a clean base task that passes on context, but does not interfere with celery mechanisms. The logic to decide whether or not the task is to be called asynchronously is now in a dispatch method. --- dojo/api_v2/views.py | 5 +- dojo/celery.py | 52 ++- dojo/celery_dispatch.py | 95 ++++++ dojo/endpoint/views.py | 3 +- dojo/engagement/services.py | 3 +- dojo/engagement/views.py | 3 +- dojo/finding/deduplication.py | 3 - dojo/finding/helper.py | 11 +- dojo/finding/views.py | 5 +- dojo/finding_group/views.py | 5 +- dojo/importers/base_importer.py | 61 +--- dojo/importers/default_importer.py | 4 +- dojo/importers/default_reimporter.py | 4 +- dojo/importers/endpoint_manager.py | 33 +- dojo/jira_link/helper.py | 23 +- dojo/management/commands/dedupe.py | 19 +- dojo/models.py | 8 +- dojo/notifications/helper.py | 107 +++--- dojo/product/helpers.py | 2 - dojo/sla_config/helpers.py | 2 - dojo/tags_signals.py | 3 +- dojo/tasks.py | 19 +- dojo/templatetags/display_tags.py | 4 +- dojo/test/views.py | 3 +- dojo/tools/tool_issue_updater.py | 6 +- dojo/utils.py | 297 ++++++++++------- unittests/test_async_delete.py | 313 ++++++++++++++++++ unittests/test_jira_import_and_pushing_api.py | 2 +- unittests/test_notifications.py | 4 +- 29 files changed, 806 insertions(+), 293 deletions(-) create mode 100644 dojo/celery_dispatch.py create mode 100644 unittests/test_async_delete.py diff --git a/dojo/api_v2/views.py b/dojo/api_v2/views.py index c1a3b12db6f..69d92161dea 100644 --- a/dojo/api_v2/views.py +++ b/dojo/api_v2/views.py @@ -46,6 +46,7 @@ ) from dojo.api_v2.prefetch.prefetcher import _Prefetcher from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.cred.queries import get_authorized_cred_mappings from dojo.endpoint.queries import ( get_authorized_endpoint_status, @@ -679,13 +680,13 @@ def update_jira_epic(self, request, pk=None): try: if engagement.has_jira_issue: - jira_helper.update_epic(engagement.id, **request.data) + dojo_dispatch_task(jira_helper.update_epic, engagement.id, **request.data) response = Response( {"info": "Jira Epic update query sent"}, status=status.HTTP_200_OK, ) else: - jira_helper.add_epic(engagement.id, **request.data) + dojo_dispatch_task(jira_helper.add_epic, engagement.id, **request.data) response = Response( {"info": "Jira Epic create query sent"}, status=status.HTTP_200_OK, diff --git a/dojo/celery.py b/dojo/celery.py index ead4a8813a8..3cf09e1bc2c 100644 --- a/dojo/celery.py +++ b/dojo/celery.py @@ -12,16 +12,56 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dojo.settings.settings") -class PgHistoryTask(Task): +class DojoAsyncTask(Task): + + """ + Base task class that provides dojo_async_task functionality without using a decorator. + + This class: + - Injects user context into task kwargs + - Tracks task calls for performance testing + - Supports all Celery features (signatures, chords, groups, chains) + """ + + def apply_async(self, args=None, kwargs=None, **options): + """Override apply_async to inject user context and track tasks.""" + from dojo.decorators import dojo_async_task_counter # noqa: PLC0415 circular import + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + if kwargs is None: + kwargs = {} + + # Inject user context if not already present + if "async_user" not in kwargs: + kwargs["async_user"] = get_current_user() + + # Control flag used for sync/async decision; never pass into the task itself + kwargs.pop("sync", None) + + # Track dispatch + dojo_async_task_counter.incr( + self.name, + args=args, + kwargs=kwargs, + ) + + # Call parent to execute async + return super().apply_async(args=args, kwargs=kwargs, **options) + + +class PgHistoryTask(DojoAsyncTask): """ Custom Celery base task that automatically applies pghistory context. - When a task is dispatched via dojo_async_task, the current pghistory - context is captured and passed in kwargs as "_pgh_context". This base - class extracts that context and applies it before running the task, - ensuring all database events share the same context as the original - request. + This class inherits from DojoAsyncTask to provide: + - User context injection and task tracking (from DojoAsyncTask) + - Automatic pghistory context application (from this class) + + When a task is dispatched via dojo_dispatch_task or dojo_async_task, the current + pghistory context is captured and passed in kwargs as "_pgh_context". This base + class extracts that context and applies it before running the task, ensuring all + database events share the same context as the original request. """ def __call__(self, *args, **kwargs): diff --git a/dojo/celery_dispatch.py b/dojo/celery_dispatch.py new file mode 100644 index 00000000000..f4ce0e3241b --- /dev/null +++ b/dojo/celery_dispatch.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, cast + +from celery.canvas import Signature + +if TYPE_CHECKING: + from collections.abc import Mapping + + +class _SupportsSi(Protocol): + def si(self, *args: Any, **kwargs: Any) -> Signature: ... + + +class _SupportsApplyAsync(Protocol): + def apply_async(self, args: Any | None = None, kwargs: Any | None = None, **options: Any) -> Any: ... + + +def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + result: dict[str, Any] = dict(kwargs or {}) + if "async_user" not in result: + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + result["async_user"] = get_current_user() + return result + + +def _inject_pghistory_context(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + """Capture and inject pghistory context if available.""" + result: dict[str, Any] = dict(kwargs or {}) + if "_pgh_context" not in result: + from dojo.pghistory_utils import get_serializable_pghistory_context # noqa: PLC0415 circular import + + if pgh_context := get_serializable_pghistory_context(): + result["_pgh_context"] = pgh_context + return result + + +def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature: + """ + Build a Celery signature with DefectDojo user context and pghistory context injected. + + - If passed a task, returns `task_or_sig.si(*args, **kwargs)`. + - If passed an existing signature, returns a cloned signature with merged kwargs. + """ + injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) + injected.pop("countdown", None) + + if isinstance(task_or_sig, Signature): + merged_kwargs = {**(task_or_sig.kwargs or {}), **injected} + return task_or_sig.clone(kwargs=merged_kwargs) + + return task_or_sig.si(*args, **injected) + + +def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signature, *args: Any, **kwargs: Any) -> Any: + """ + Dispatch a task/signature using DefectDojo semantics. + + - Inject `async_user` if missing. + - Capture and inject pghistory context if available. + - Respect `sync=True` (foreground execution) and user `block_execution`. + - Support `countdown=` for async dispatch. + + Returns: + - async: AsyncResult-like return from Celery + - sync: underlying return value of the task + + """ + from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import + + countdown = cast("int", kwargs.pop("countdown", 0)) + injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) + + sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected) + sig_kwargs = dict(sig.kwargs or {}) + + if we_want_async(*sig.args, func=getattr(sig, "type", None), **sig_kwargs): + # DojoAsyncTask.apply_async tracks async dispatch. Avoid double-counting here. + return sig.apply_async(countdown=countdown) + + # Track foreground execution as a "created task" as well (matches historical dojo_async_task behavior) + dojo_async_task_counter.incr(str(sig.task), args=sig.args, kwargs=sig_kwargs) + + sig_kwargs.pop("sync", None) + sig = sig.clone(kwargs=sig_kwargs) + eager = sig.apply() + try: + return eager.get(propagate=True) + except RuntimeError: + # Since we are intentionally running synchronously, we can propagate exceptions directly, and enable sync subtasks + # If the requests desires this. Celery docs explain that this is a rare use case, but we support it _just in case_ + return eager.get(propagate=True, disable_sync_subtasks=False) diff --git a/dojo/endpoint/views.py b/dojo/endpoint/views.py index f66869d35b2..caa48f02757 100644 --- a/dojo/endpoint/views.py +++ b/dojo/endpoint/views.py @@ -18,6 +18,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.queries import get_authorized_endpoints_for_queryset from dojo.endpoint.utils import clean_hosts_run, endpoint_meta_import from dojo.filters import EndpointFilter, EndpointFilterWithoutObjectLookups @@ -345,7 +346,7 @@ def endpoint_bulk_update_all(request, pid=None): product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct()) endpoints.delete() for prod in product_calc: - calculate_grade(prod.id) + dojo_dispatch_task(calculate_grade, prod.id) if skipped_endpoint_count > 0: add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.") diff --git a/dojo/engagement/services.py b/dojo/engagement/services.py index cd70af1ea2c..42a7c1c05e4 100644 --- a/dojo/engagement/services.py +++ b/dojo/engagement/services.py @@ -5,6 +5,7 @@ from django.dispatch import receiver import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Engagement logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ def close_engagement(eng): eng.save() if jira_helper.get_jira_project(eng): - jira_helper.close_epic(eng.id, push_to_jira=True) + dojo_dispatch_task(jira_helper.close_epic, eng.id, push_to_jira=True) def reopen_engagement(eng): diff --git a/dojo/engagement/views.py b/dojo/engagement/views.py index 4eb5398cb61..fbe0ca6b496 100644 --- a/dojo/engagement/views.py +++ b/dojo/engagement/views.py @@ -37,6 +37,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import save_endpoints_to_add from dojo.engagement.queries import get_authorized_engagements from dojo.engagement.services import close_engagement, reopen_engagement @@ -392,7 +393,7 @@ def copy_engagement(request, eid): form = DoneForm(request.POST) if form.is_valid(): engagement_copy = engagement.copy() - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/finding/deduplication.py b/dojo/finding/deduplication.py index 54517108f25..51e34015b61 100644 --- a/dojo/finding/deduplication.py +++ b/dojo/finding/deduplication.py @@ -8,7 +8,6 @@ from django.db.models.query_utils import Q from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.models import Finding, System_Settings logger = logging.getLogger(__name__) @@ -45,13 +44,11 @@ def get_finding_models_for_deduplication(finding_ids): ) -@dojo_async_task @app.task def do_dedupe_finding_task(new_finding_id, *args, **kwargs): return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs) -@dojo_async_task @app.task def do_dedupe_batch_task(finding_ids, *args, **kwargs): """ diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 51c65553742..74bf3ec7279 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -16,7 +16,6 @@ import dojo.jira_link.helper as jira_helper import dojo.risk_acceptance.helper as ra_helper from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add from dojo.file_uploads.helper import delete_related_files from dojo.finding.deduplication import ( @@ -395,7 +394,6 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group finding_group.findings.add(*findings) -@dojo_async_task @app.task def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed @@ -440,7 +438,9 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option if product_grading_option: if system_settings.enable_product_grade: - calculate_grade(finding.test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, finding.test.engagement.product.id) else: deduplicationLogger.debug("skipping product grading because it's disabled in system settings") @@ -457,7 +457,6 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option jira_helper.push_to_jira(finding.finding_group) -@dojo_async_task @app.task def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True, issue_updater_option=True, push_to_jira=False, user=None, **kwargs): @@ -500,7 +499,9 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op tool_issue_updater.async_tool_issue_update(finding) if product_grading_option and system_settings.enable_product_grade: - calculate_grade(findings[0].test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, findings[0].test.engagement.product.id) if push_to_jira: for finding in findings: diff --git a/dojo/finding/views.py b/dojo/finding/views.py index 3269f92902a..20f37713737 100644 --- a/dojo/finding/views.py +++ b/dojo/finding/views.py @@ -38,6 +38,7 @@ user_is_authorized, ) from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( AcceptedFindingFilter, AcceptedFindingFilterWithoutObjectLookups, @@ -1099,7 +1100,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict): product = finding.test.engagement.product finding.delete() # Update the grade of the product async - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) # Add a message to the request that the finding was successfully deleted messages.add_message( request, @@ -1374,7 +1375,7 @@ def copy_finding(request, fid): test = form.cleaned_data.get("test") product = finding.test.engagement.product finding_copy = finding.copy(test=test) - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/finding_group/views.py b/dojo/finding_group/views.py index 451d4dcd720..e29c401b80d 100644 --- a/dojo/finding_group/views.py +++ b/dojo/finding_group/views.py @@ -16,6 +16,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( FindingFilter, FindingFilterWithoutObjectLookups, @@ -100,7 +101,7 @@ def view_finding_group(request, fgid): elif not finding_group.has_jira_issue: jira_helper.finding_group_link_jira(request, finding_group, jira_issue) elif push_to_jira: - jira_helper.push_to_jira(finding_group, sync=True) + dojo_dispatch_task(jira_helper.push_to_jira, finding_group, sync=True) finding_group.save() return HttpResponseRedirect(reverse("view_test", args=(finding_group.test.id,))) @@ -200,7 +201,7 @@ def push_to_jira(request, fgid): # it may look like success here, but the push_to_jira are swallowing exceptions # but cant't change too much now without having a test suite, so leave as is for now with the addition warning message to check alerts for background errors. - if jira_helper.push_to_jira(group, sync=True): + if dojo_dispatch_task(jira_helper.push_to_jira, group, sync=True): messages.add_message( request, messages.SUCCESS, diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 3e00a216cbd..53df8618443 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -3,7 +3,6 @@ import time from collections.abc import Iterable -from celery import chord, group from django.conf import settings from django.core.exceptions import ValidationError from django.core.files.base import ContentFile @@ -14,7 +13,7 @@ import dojo.finding.helper as finding_helper import dojo.risk_acceptance.helper as ra_helper -from dojo import utils +from dojo.celery_dispatch import dojo_dispatch_task from dojo.importers.endpoint_manager import EndpointManager from dojo.importers.location_manager import LocationManager from dojo.importers.options import ImporterOptions @@ -31,7 +30,6 @@ Endpoint, FileUpload, Finding, - System_Settings, Test, Test_Import, Test_Import_Finding_Action, @@ -676,47 +674,6 @@ def update_test_type_from_internal_test(self, internal_test: ParserTest) -> None self.test.test_type.dynamic_tool = dynamic_tool self.test.test_type.save() - def maybe_launch_post_processing_chord( - self, - post_processing_task_signatures, - current_batch_number: int, - max_batch_size: int, - * - is_final_batch: bool, - ) -> tuple[list, int, bool]: - """ - Helper to optionally launch a chord of post-processing tasks with a calculate-grade callback - when async is desired. Uses exponential batch sizing up to the configured max batch size. - - Returns a tuple of (post_processing_task_signatures, current_batch_number, launched) - where launched indicates whether a chord/group was dispatched and signatures were reset. - """ - launched = False - if not post_processing_task_signatures: - return post_processing_task_signatures, current_batch_number, launched - - current_batch_size = min(2 ** current_batch_number, max_batch_size) - batch_full = len(post_processing_task_signatures) >= current_batch_size - - if batch_full or is_final_batch: - product = self.test.engagement.product - system_settings = System_Settings.objects.get() - if system_settings.enable_product_grade: - calculate_grade_signature = utils.calculate_grade.si(product.id) - chord(post_processing_task_signatures)(calculate_grade_signature) - else: - group(post_processing_task_signatures).apply_async() - - logger.debug( - f"Launched chord with {len(post_processing_task_signatures)} tasks (batch #{current_batch_number}, size: {len(post_processing_task_signatures)})", - ) - post_processing_task_signatures = [] - if not is_final_batch: - current_batch_number += 1 - launched = True - - return post_processing_task_signatures, current_batch_number, launched - def verify_tool_configuration_from_test(self): """ Verify that the Tool_Configuration supplied along with the @@ -988,11 +945,23 @@ def mitigate_finding( ra_helper.risk_unaccept(self.user, finding, perform_save=False, post_comments=False) if settings.V3_FEATURE_LOCATIONS: # Mitigate the location statuses - self.location_manager.mitigate_location_status(finding.locations.all(), self.user, kwuser=self.user, sync=True) + dojo_dispatch_task( + self.location_manager.mitigate_location_status, + finding.locations.all(), + self.user, + kwuser=self.user, + sync=True, + ) else: # TODO: Delete this after the move to Locations # Mitigate the endpoint statuses - self.endpoint_manager.mitigate_endpoint_status(finding.status_finding.all(), self.user, kwuser=self.user, sync=True) + dojo_dispatch_task( + self.endpoint_manager.mitigate_endpoint_status, + finding.status_finding.all(), + self.user, + kwuser=self.user, + sync=True, + ) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 6fc2beff074..a35eaa27496 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -7,6 +7,7 @@ from django.urls import reverse import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser from dojo.importers.options import ImporterOptions @@ -265,7 +266,8 @@ def process_findings( batch_finding_ids.clear() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", push_to_jira, len(finding_ids_batch), is_final_finding) - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, dedupe_option=True, rules_option=True, diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 863a2cf0212..8d8bcc2d13e 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -7,6 +7,7 @@ import dojo.finding.helper as finding_helper import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.deduplication import ( find_candidates_for_deduplication_hash, find_candidates_for_deduplication_uid_or_hash, @@ -432,7 +433,8 @@ def process_findings( if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, dedupe_option=True, rules_option=True, diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index 8817ff71bdb..d390067b63c 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -5,7 +5,7 @@ from django.utils import timezone from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import endpoint_get_or_create from dojo.models import ( Dojo_User, @@ -19,17 +19,15 @@ # TODO: Delete this after the move to Locations class EndpointManager: - @dojo_async_task - @app.task() + @app.task def add_endpoints_to_unsaved_finding( - self, - finding: Finding, + finding: Finding, # noqa: N805 endpoints: list[Endpoint], **kwargs: dict, ) -> None: """Creates Endpoint objects for a single finding and creates the link via the endpoint status""" logger.debug(f"IMPORT_SCAN: Adding {len(endpoints)} endpoints to finding: {finding}") - self.clean_unsaved_endpoints(endpoints) + EndpointManager.clean_unsaved_endpoints(endpoints) for endpoint in endpoints: ep = None eps = [] @@ -42,7 +40,8 @@ def add_endpoints_to_unsaved_finding( path=endpoint.path, query=endpoint.query, fragment=endpoint.fragment, - product=finding.test.engagement.product) + product=finding.test.engagement.product, + ) eps.append(ep) except (MultipleObjectsReturned): msg = ( @@ -59,11 +58,9 @@ def add_endpoints_to_unsaved_finding( logger.debug(f"IMPORT_SCAN: {len(endpoints)} endpoints imported") - @dojo_async_task - @app.task() + @app.task def mitigate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 user: Dojo_User, **kwargs: dict, ) -> None: @@ -86,11 +83,9 @@ def mitigate_endpoint_status( batch_size=1000, ) - @dojo_async_task - @app.task() + @app.task def reactivate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 **kwargs: dict, ) -> None: """Reactivate all endpoint status objects that are supplied""" @@ -119,10 +114,10 @@ def chunk_endpoints_and_disperse( endpoints: list[Endpoint], **kwargs: dict, ) -> None: - self.add_endpoints_to_unsaved_finding(finding, endpoints, sync=True) + dojo_dispatch_task(self.add_endpoints_to_unsaved_finding, finding, endpoints, sync=True) + @staticmethod def clean_unsaved_endpoints( - self, endpoints: list[Endpoint], ) -> None: """ @@ -140,7 +135,7 @@ def chunk_endpoints_and_reactivate( endpoint_status_list: list[Endpoint_Status], **kwargs: dict, ) -> None: - self.reactivate_endpoint_status(endpoint_status_list, sync=True) + dojo_dispatch_task(self.reactivate_endpoint_status, endpoint_status_list, sync=True) def chunk_endpoints_and_mitigate( self, @@ -148,7 +143,7 @@ def chunk_endpoints_and_mitigate( user: Dojo_User, **kwargs: dict, ) -> None: - self.mitigate_endpoint_status(endpoint_status_list, user, sync=True) + dojo_dispatch_task(self.mitigate_endpoint_status, endpoint_status_list, user, sync=True) def update_endpoint_status( self, diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index 6f7774c7cc3..7a9ccaa3e96 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -18,7 +18,7 @@ from requests.auth import HTTPBasicAuth from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.forms import JIRAEngagementForm, JIRAProjectForm from dojo.models import ( Engagement, @@ -763,20 +763,19 @@ def push_to_jira(obj, *args, **kwargs): if isinstance(obj, Finding): if obj.has_finding_group: logger.debug("pushing finding group for %s to JIRA", obj) - return push_finding_group_to_jira(obj.finding_group.id, *args, **kwargs) - return push_finding_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.finding_group.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Finding_Group): - return push_finding_group_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Engagement): - return push_engagement_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_engagement_to_jira, obj.id, *args, **kwargs) logger.error("unsupported object passed to push_to_jira: %s %i %s", obj.__name__, obj.id, obj) return None # we need thre separate celery tasks due to the decorators we're using to map to/from ids -@dojo_async_task @app.task def push_finding_to_jira(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -789,7 +788,6 @@ def push_finding_to_jira(finding_id, *args, **kwargs): return add_jira_issue(finding, *args, **kwargs) -@dojo_async_task @app.task def push_finding_group_to_jira(finding_group_id, *args, **kwargs): finding_group = get_object_or_none(Finding_Group, id=finding_group_id) @@ -806,7 +804,6 @@ def push_finding_group_to_jira(finding_group_id, *args, **kwargs): return add_jira_issue(finding_group, *args, **kwargs) -@dojo_async_task @app.task def push_engagement_to_jira(engagement_id, *args, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -815,8 +812,8 @@ def push_engagement_to_jira(engagement_id, *args, **kwargs): return None if engagement.has_jira_issue: - return update_epic(engagement.id, *args, **kwargs) - return add_epic(engagement.id, *args, **kwargs) + return dojo_dispatch_task(update_epic, engagement.id, *args, **kwargs) + return dojo_dispatch_task(add_epic, engagement.id, *args, **kwargs) def add_issues_to_epic(jira, obj, epic_id, issue_keys, *, ignore_epics=True): @@ -1396,7 +1393,6 @@ def jira_check_attachment(issue, source_file_name): return file_exists -@dojo_async_task @app.task def close_epic(engagement_id, push_to_jira, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1445,7 +1441,6 @@ def close_epic(engagement_id, push_to_jira, **kwargs): return False -@dojo_async_task @app.task def update_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1492,7 +1487,6 @@ def update_epic(engagement_id, **kwargs): return False -@dojo_async_task @app.task def add_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1601,10 +1595,9 @@ def add_comment(obj, note, *, force_push=False, **kwargs): return False # Call the internal task with IDs (runs synchronously within this task) - return add_comment_internal(jira_issue.id, note.id, force_push=force_push, **kwargs) + return dojo_dispatch_task(add_comment_internal, jira_issue.id, note.id, force_push=force_push, **kwargs) -@dojo_async_task @app.task def add_comment_internal(jira_issue_id, note_id, *, force_push=False, **kwargs): """Internal Celery task that adds a comment to a JIRA issue.""" diff --git a/dojo/management/commands/dedupe.py b/dojo/management/commands/dedupe.py index a4cbee519a6..3d17f4a7fe9 100644 --- a/dojo/management/commands/dedupe.py +++ b/dojo/management/commands/dedupe.py @@ -131,14 +131,25 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup mass_model_updater(Finding, findings, do_dedupe_finding_task_internal, fields=None, order="desc", page_size=100, log_prefix="deduplicating ") else: # async tasks only need the id - mass_model_updater(Finding, findings.only("id"), lambda f: do_dedupe_finding_task(f.id), fields=None, order="desc", log_prefix="deduplicating ") + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + mass_model_updater( + Finding, + findings.only("id"), + lambda f: dojo_dispatch_task(do_dedupe_finding_task, f.id), + fields=None, + order="desc", + log_prefix="deduplicating ", + ) if dedupe_sync: # update the grading (if enabled) and only useful in sync mode # in async mode the background task that grades products every hour will pick it up logger.debug("Updating grades for products...") for product in Product.objects.all(): - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) logger.info("######## Done deduplicating (%s) ########", ("foreground" if dedupe_sync else "tasks submitted to celery")) else: @@ -185,7 +196,9 @@ def _dedupe_batch_mode(self, findings_queryset, *, dedupe_sync: bool = True): else: # Asynchronous: submit task with finding IDs logger.debug(f"Submitting async batch task for {len(batch_finding_ids)} findings for test {test_id}") - do_dedupe_batch_task(batch_finding_ids) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(do_dedupe_batch_task, batch_finding_ids) total_processed += len(batch_finding_ids) batch_finding_ids = [] diff --git a/dojo/models.py b/dojo/models.py index 2281a2bfac8..5f470c83d93 100644 --- a/dojo/models.py +++ b/dojo/models.py @@ -1092,7 +1092,9 @@ def save(self, *args, **kwargs): super(Product, product).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(self, products, severities=severities) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, self, products, severities=severities) def clean(self): sla_days = [self.critical, self.high, self.medium, self.low] @@ -1252,7 +1254,9 @@ def save(self, *args, **kwargs): super(SLA_Configuration, sla_config).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(sla_config, Product.objects.filter(id=self.id)) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, sla_config, Product.objects.filter(id=self.id)) def get_absolute_url(self): return reverse("view_product", args=[str(self.id)]) diff --git a/dojo/notifications/helper.py b/dojo/notifications/helper.py index c4458daec01..7318bae91af 100644 --- a/dojo/notifications/helper.py +++ b/dojo/notifications/helper.py @@ -18,7 +18,8 @@ from dojo import __version__ as dd_version from dojo.authorization.roles_permissions import Permissions from dojo.celery import app -from dojo.decorators import dojo_async_task, we_want_async +from dojo.celery_dispatch import dojo_dispatch_task +from dojo.decorators import we_want_async from dojo.labels import get_labels from dojo.models import ( Alerts, @@ -45,6 +46,26 @@ labels = get_labels() +def get_manager_class_instance(): + default_manager = NotificationManager + notification_manager_class = default_manager + if isinstance( + ( + notification_manager := getattr( + settings, + "NOTIFICATION_MANAGER", + default_manager, + ) + ), + str, + ): + with suppress(ModuleNotFoundError): + module_name, _separator, class_name = notification_manager.rpartition(".") + module = importlib.import_module(module_name) + notification_manager_class = getattr(module, class_name) + return notification_manager_class() + + def create_notification( event: str | None = None, title: str | None = None, @@ -62,23 +83,7 @@ def create_notification( **kwargs: dict, ) -> None: """Create an instance of a NotificationManager and dispatch the notification.""" - default_manager = NotificationManager - notification_manager_class = default_manager - if isinstance( - ( - notification_manager := getattr( - settings, - "NOTIFICATION_MANAGER", - default_manager, - ) - ), - str, - ): - with suppress(ModuleNotFoundError): - module_name, _separator, class_name = notification_manager.rpartition(".") - module = importlib.import_module(module_name) - notification_manager_class = getattr(module, class_name) - notification_manager_class().create_notification( + get_manager_class_instance().create_notification( event=event, title=title, finding=finding, @@ -199,8 +204,6 @@ class SlackNotificationManger(NotificationManagerHelpers): """Manger for slack notifications and their helpers.""" - @dojo_async_task - @app.task def send_slack_notification( self, event: str, @@ -317,8 +320,6 @@ class MSTeamsNotificationManger(NotificationManagerHelpers): """Manger for Microsoft Teams notifications and their helpers.""" - @dojo_async_task - @app.task def send_msteams_notification( self, event: str, @@ -368,8 +369,6 @@ class EmailNotificationManger(NotificationManagerHelpers): """Manger for email notifications and their helpers.""" - @dojo_async_task - @app.task def send_mail_notification( self, event: str, @@ -420,8 +419,6 @@ class WebhookNotificationManger(NotificationManagerHelpers): ERROR_PERMANENT = "permanent" ERROR_TEMPORARY = "temporary" - @dojo_async_task - @app.task def send_webhooks_notification( self, event: str, @@ -480,11 +477,7 @@ def send_webhooks_notification( endpoint.first_error = now endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_TMP # In case of failure within one day, endpoint can be deactivated temporally only for one minute - self._webhook_reactivation.apply_async( - args=[self], - kwargs={"endpoint_id": endpoint.pk}, - countdown=60, - ) + webhook_reactivation.apply_async(kwargs={"endpoint_id": endpoint.pk}, countdown=60) # There is no reason to keep endpoint active if it is returning 4xx errors else: endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_PERMANENT @@ -559,7 +552,6 @@ def _test_webhooks_notification(self, endpoint: Notification_Webhooks) -> None: # in "send_webhooks_notification", we are doing deeper analysis, why it failed # for now, "raise_for_status" should be enough - @app.task(ignore_result=True) def _webhook_reactivation(self, endpoint_id: int, **_kwargs: dict): endpoint = Notification_Webhooks.objects.get(pk=endpoint_id) # User already changed status of endpoint @@ -832,9 +824,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Slack Notification") - self._get_manager_instance("slack").send_slack_notification( + dojo_dispatch_task( + send_slack_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -844,9 +837,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending MSTeams Notification") - self._get_manager_instance("msteams").send_msteams_notification( + dojo_dispatch_task( + send_msteams_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -856,9 +850,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Mail Notification") - self._get_manager_instance("mail").send_mail_notification( + dojo_dispatch_task( + send_mail_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -868,13 +863,43 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Webhooks Notification") - self._get_manager_instance("webhooks").send_webhooks_notification( + dojo_dispatch_task( + send_webhooks_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) +@app.task +def send_slack_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("slack").send_slack_notification(event, user=user, **kwargs) + + +@app.task +def send_msteams_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("msteams").send_msteams_notification(event, user=user, **kwargs) + + +@app.task +def send_mail_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("mail").send_mail_notification(event, user=user, **kwargs) + + +@app.task +def send_webhooks_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("webhooks").send_webhooks_notification(event, user=user, **kwargs) + + +@app.task(ignore_result=True) +def webhook_reactivation(endpoint_id: int, **_kwargs: dict) -> None: + get_manager_class_instance()._get_manager_instance("webhooks")._webhook_reactivation(endpoint_id=endpoint_id) + + @app.task(ignore_result=True) def webhook_status_cleanup(*_args: list, **_kwargs: dict): # If some endpoint was affected by some outage (5xx, 429, Timeout) but it was clean during last 24 hours, @@ -902,4 +927,4 @@ def webhook_status_cleanup(*_args: list, **_kwargs: dict): ) for endpoint in broken_endpoints: manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=endpoint.pk) + manager._webhook_reactivation(endpoint_id=endpoint.pk) diff --git a/dojo/product/helpers.py b/dojo/product/helpers.py index 8247cad4fa8..cdb0750f317 100644 --- a/dojo/product/helpers.py +++ b/dojo/product/helpers.py @@ -5,14 +5,12 @@ from django.db.models import Q from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.location.models import Location from dojo.models import Endpoint, Engagement, Finding, Product, Test logger = logging.getLogger(__name__) -@dojo_async_task @app.task def propagate_tags_on_product(product_id, *args, **kwargs): with contextlib.suppress(Product.DoesNotExist): diff --git a/dojo/sla_config/helpers.py b/dojo/sla_config/helpers.py index da5899a85b0..045456f38d7 100644 --- a/dojo/sla_config/helpers.py +++ b/dojo/sla_config/helpers.py @@ -1,14 +1,12 @@ import logging from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.models import Finding, Product, SLA_Configuration, System_Settings from dojo.utils import get_custom_method, mass_model_updater logger = logging.getLogger(__name__) -@dojo_async_task @app.task def async_update_sla_expiration_dates_sla_config_sync(sla_config: SLA_Configuration, products: list[Product], *args, severities: list[str] | None = None, **kwargs): if method := get_custom_method("FINDING_SLA_EXPIRATION_CALCULATION_METHOD"): diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 58682421d0b..0fea7ae8ad5 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -4,6 +4,7 @@ from django.db.models import signals from django.dispatch import receiver +from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import Location, LocationFindingReference, LocationProductReference from dojo.models import Endpoint, Engagement, Finding, Product, Test from dojo.product import helpers as async_product_funcs @@ -20,7 +21,7 @@ def product_tags_post_add_remove(sender, instance, action, **kwargs): running_async_process = instance.running_async_process # Check if the async process is already running to avoid calling it a second time if not running_async_process and inherit_product_tags(instance): - async_product_funcs.propagate_tags_on_product(instance.id, countdown=5) + dojo_dispatch_task(async_product_funcs.propagate_tags_on_product, instance.id, countdown=5) instance.running_async_process = True diff --git a/dojo/tasks.py b/dojo/tasks.py index b934abc9f02..1bbe104783b 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -2,6 +2,7 @@ from datetime import timedelta import pghistory +from celery import Task from celery.utils.log import get_task_logger from django.apps import apps from django.conf import settings @@ -12,7 +13,7 @@ from dojo.auditlog import run_flush_auditlog from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.helper import fix_loop_duplicates from dojo.location.models import Location from dojo.management.commands.jira_status_reconciliation import jira_status_reconciliation @@ -31,7 +32,7 @@ def log_generic_alert(source, title, description): @app.task(bind=True) -def add_alerts(self, runinterval): +def add_alerts(self, runinterval, *args, **kwargs): now = timezone.now() upcoming_engagements = Engagement.objects.filter(target_start__gt=now + timedelta(days=3), target_start__lt=now + timedelta(days=3) + runinterval).order_by("target_start") @@ -73,7 +74,7 @@ def add_alerts(self, runinterval): if system_settings.enable_product_grade: products = Product.objects.all() for product in products: - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) @app.task(bind=True) @@ -170,11 +171,18 @@ def _async_dupe_delete_impl(): if system_settings.enable_product_grade: logger.info("performing batch product grading for %s products", len(affected_products)) for product in affected_products: - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) -@app.task(ignore_result=False) +@app.task(ignore_result=False, base=Task) def celery_status(): + """ + Simple health check task to verify Celery is running. + + Uses base Task class (not PgHistoryTask) since it doesn't need: + - User context tracking + - Pghistory context (no database modifications) + """ return True @@ -242,7 +250,6 @@ def clear_sessions(*args, **kwargs): call_command("clearsessions") -@dojo_async_task @app.task def update_watson_search_index_for_model(model_name, pk_list, *args, **kwargs): """ diff --git a/dojo/templatetags/display_tags.py b/dojo/templatetags/display_tags.py index 0fd6b604f06..18c4b7f6f4a 100644 --- a/dojo/templatetags/display_tags.py +++ b/dojo/templatetags/display_tags.py @@ -363,7 +363,9 @@ def product_grade(product): if system_settings.enable_product_grade and product: prod_numeric_grade = product.prod_numeric_grade if not prod_numeric_grade or prod_numeric_grade is None: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) if prod_numeric_grade: if prod_numeric_grade >= system_settings.product_grade_a: grade = "A" diff --git a/dojo/test/views.py b/dojo/test/views.py index 96f7c54e470..c8b52bb6f14 100644 --- a/dojo/test/views.py +++ b/dojo/test/views.py @@ -27,6 +27,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.engagement.queries import get_authorized_engagements from dojo.filters import FindingFilter, FindingFilterWithoutObjectLookups, TemplateFindingFilter, TestImportFilter from dojo.finding.queries import prefetch_for_findings @@ -345,7 +346,7 @@ def copy_test(request, tid): engagement = form.cleaned_data.get("engagement") product = test.engagement.product test_copy = test.copy(engagement=engagement) - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index 854fb989113..8211e166eed 100644 --- a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -3,7 +3,7 @@ import pghistory from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Finding from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API from dojo.tools.api_sonarqube.updater import SonarQubeApiUpdater @@ -15,7 +15,7 @@ def async_tool_issue_update(finding, *args, **kwargs): if is_tool_issue_updater_needed(finding): - tool_issue_updater(finding.id) + dojo_dispatch_task(tool_issue_updater, finding.id) def is_tool_issue_updater_needed(finding, *args, **kwargs): @@ -23,7 +23,6 @@ def is_tool_issue_updater_needed(finding, *args, **kwargs): return test_type.name == SCAN_SONARQUBE_API -@dojo_async_task @app.task def tool_issue_updater(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -37,7 +36,6 @@ def tool_issue_updater(finding_id, *args, **kwargs): SonarQubeApiUpdater().update_sonarqube_finding(finding) -@dojo_async_task @app.task def update_findings_from_source_issues(**kwargs): # Wrap with pghistory context for audit trail diff --git a/dojo/utils.py b/dojo/utils.py index 980cd107659..ba1b5ed0d7c 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -47,7 +47,6 @@ from dojo.authorization.roles_permissions import Permissions from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.finding.queries import get_authorized_findings from dojo.github import ( add_external_issue_github, @@ -1057,7 +1056,6 @@ def handle_uploaded_selenium(f, cred): cred.save() -@dojo_async_task @app.task def add_external_issue(finding_id, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1073,7 +1071,6 @@ def add_external_issue(finding_id, external_issue_provider, **kwargs): add_external_issue_github(finding, prod, eng) -@dojo_async_task @app.task def update_external_issue(finding_id, old_status, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1088,7 +1085,6 @@ def update_external_issue(finding_id, old_status, external_issue_provider, **kwa update_external_issue_github(finding, prod, eng) -@dojo_async_task @app.task def close_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1103,7 +1099,6 @@ def close_external_issue(finding_id, note, external_issue_provider, **kwargs): close_external_issue_github(finding, note, prod, eng) -@dojo_async_task @app.task def reopen_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1259,7 +1254,6 @@ def grade_product(crit, high, med, low): return max(health, 5) -@dojo_async_task @app.task def calculate_grade(product_id, *args, **kwargs): product = get_object_or_none(Product, id=product_id) @@ -1317,7 +1311,9 @@ def calculate_grade_internal(product, *args, **kwargs): def perform_product_grading(product): system_settings = System_Settings.objects.get() if system_settings.enable_product_grade: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) def get_celery_worker_status(): @@ -2037,130 +2033,187 @@ def is_finding_groups_enabled(): return get_system_setting("enable_finding_groups") -class async_delete: - def __init__(self, *args, **kwargs): - self.mapping = { - "Product_Type": [ - (Endpoint, "product__prod_type__id"), - (Finding, "test__engagement__product__prod_type__id"), - (Test, "engagement__product__prod_type__id"), - (Engagement, "product__prod_type__id"), - (Product, "prod_type__id")], - "Product": [ - (Endpoint, "product__id"), - (Finding, "test__engagement__product__id"), - (Test, "engagement__product__id"), - (Engagement, "product__id")], - "Engagement": [ - (Finding, "test__engagement__id"), - (Test, "engagement__id")], - "Test": [(Finding, "test__id")], - } +# Mapping of object types to their related models for cascading deletes +ASYNC_DELETE_MAPPING = { + "Product_Type": [ + (Endpoint, "product__prod_type__id"), + (Finding, "test__engagement__product__prod_type__id"), + (Test, "engagement__product__prod_type__id"), + (Engagement, "product__prod_type__id"), + (Product, "prod_type__id")], + "Product": [ + (Endpoint, "product__id"), + (Finding, "test__engagement__product__id"), + (Test, "engagement__product__id"), + (Engagement, "product__id")], + "Engagement": [ + (Finding, "test__engagement__id"), + (Test, "engagement__id")], + "Test": [(Finding, "test__id")], +} + + +def _get_object_name(obj): + """Get the class name of an object or model class.""" + if obj.__class__.__name__ == "ModelBase": + return obj.__name__ + return obj.__class__.__name__ - @dojo_async_task - @app.task - def delete_chunk(self, objects, **kwargs): - # Now delete all objects with retry for deadlocks - max_retries = 3 - for obj in objects: - retry_count = 0 - while retry_count < max_retries: - try: - obj.delete() - break # Success, exit retry loop - except OperationalError as e: - error_msg = str(e) - if "deadlock detected" in error_msg.lower(): - retry_count += 1 - if retry_count < max_retries: - # Exponential backoff with jitter - wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 - logger.warning( - f"ASYNC_DELETE: Deadlock detected deleting {self.get_object_name(obj)} {obj.pk}, " - f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", - ) - time.sleep(wait_time) - # Refresh object from DB before retry - obj.refresh_from_db() - else: - logger.error( - f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {self.get_object_name(obj)} {obj.pk}: {e}", - ) - raise + +@app.task +def async_delete_chunk_task(objects, **kwargs): + """ + Module-level Celery task to delete a chunk of objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + max_retries = 3 + for obj in objects: + retry_count = 0 + while retry_count < max_retries: + try: + obj.delete() + break # Success, exit retry loop + except OperationalError as e: + error_msg = str(e) + if "deadlock detected" in error_msg.lower(): + retry_count += 1 + if retry_count < max_retries: + # Exponential backoff with jitter + wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 + logger.warning( + f"ASYNC_DELETE: Deadlock detected deleting {_get_object_name(obj)} {obj.pk}, " + f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", + ) + time.sleep(wait_time) + # Refresh object from DB before retry + obj.refresh_from_db() else: - # Not a deadlock, re-raise + logger.error( + f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {_get_object_name(obj)} {obj.pk}: {e}", + ) raise - except AssertionError: - logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. Skipping") - # The id must be None - # The object has already been deleted elsewhere - break - except LogEntry.MultipleObjectsReturned: - # Delete the log entrys first, then delete - LogEntry.objects.filter( - content_type=ContentType.objects.get_for_model(obj.__class__), - object_pk=str(obj.pk), - action=LogEntry.Action.DELETE, - ).delete() - # Now delete the object again (no retry needed for this case) - obj.delete() - break - - @dojo_async_task - @app.task - def delete(self, obj, **kwargs): - logger.debug("ASYNC_DELETE: Deleting " + self.get_object_name(obj) + ": " + str(obj)) - model_list = self.mapping.get(self.get_object_name(obj), None) - if model_list: - # The object to be deleted was found in the object list - self.crawl(obj, model_list) - else: - # The object is not supported in async delete, delete normally - logger.debug("ASYNC_DELETE: " + self.get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) - obj.delete() - - @dojo_async_task - @app.task - def crawl(self, obj, model_list, **kwargs): - logger.debug("ASYNC_DELETE: Crawling " + self.get_object_name(obj) + ": " + str(obj)) - with Endpoint.allow_endpoint_init(): # TODO: Delete this after the move to Locations - for model_info in model_list: - task_results = [] - model = model_info[0] - model_query = model_info[1] - filter_dict = {model_query: obj.id} - # Only fetch the IDs since we will make a list of IDs in the following function call - objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") - logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + self.get_object_name(model) + "s in chunks") - chunks = self.chunk_list(model, objects_to_delete) - for chunk in chunks: - logger.debug(f"deleting {len(chunk)} {self.get_object_name(model)}") - result = self.delete_chunk(chunk) - # Collect async task results to wait for them all at once - if hasattr(result, "get"): - task_results.append(result) - # Wait for all chunk deletions to complete (they run in parallel) - for task_result in task_results: - task_result.get(timeout=300) # 5 minute timeout per chunk - # Now delete the main object after all chunks are done - result = self.delete_chunk([obj]) - # Wait for final deletion to complete + else: + # Not a deadlock, re-raise + raise + except AssertionError: + logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. Skipping") + # The id must be None + # The object has already been deleted elsewhere + break + except LogEntry.MultipleObjectsReturned: + # Delete the log entrys first, then delete + LogEntry.objects.filter( + content_type=ContentType.objects.get_for_model(obj.__class__), + object_pk=str(obj.pk), + action=LogEntry.Action.DELETE, + ).delete() + # Now delete the object again (no retry needed for this case) + obj.delete() + break + + +@app.task +def async_delete_crawl_task(obj, model_list, **kwargs): + """ + Module-level Celery task to crawl and delete related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Crawling " + _get_object_name(obj) + ": " + str(obj)) + for model_info in model_list: + task_results = [] + model = model_info[0] + model_query = model_info[1] + filter_dict = {model_query: obj.id} + # Only fetch the IDs since we will make a list of IDs in the following function call + objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") + logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + _get_object_name(model) + "s in chunks") + chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE") + chunks = [objects_to_delete[i:i + chunk_size] for i in range(0, len(objects_to_delete), chunk_size)] + logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunks)) + " chunks of " + str(chunk_size)) + for chunk in chunks: + logger.debug(f"deleting {len(chunk)} {_get_object_name(model)}") + result = dojo_dispatch_task(async_delete_chunk_task, list(chunk)) + # Collect async task results to wait for them all at once if hasattr(result, "get"): - result.get(timeout=300) # 5 minute timeout - logger.debug("ASYNC_DELETE: Successfully deleted " + self.get_object_name(obj) + ": " + str(obj)) + task_results.append(result) + # Wait for all chunk deletions to complete (they run in parallel) + for task_result in task_results: + task_result.get(timeout=300) # 5 minute timeout per chunk + # Now delete the main object after all chunks are done + result = dojo_dispatch_task(async_delete_chunk_task, [obj]) + # Wait for final deletion to complete + if hasattr(result, "get"): + result.get(timeout=300) # 5 minute timeout + logger.debug("ASYNC_DELETE: Successfully deleted " + _get_object_name(obj) + ": " + str(obj)) - def chunk_list(self, model, full_list): + +@app.task +def async_delete_task(obj, **kwargs): + """ + Module-level Celery task to delete an object and its related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Deleting " + _get_object_name(obj) + ": " + str(obj)) + model_list = ASYNC_DELETE_MAPPING.get(_get_object_name(obj)) + if model_list: + # The object to be deleted was found in the object list + dojo_dispatch_task(async_delete_crawl_task, obj, model_list) + else: + # The object is not supported in async delete, delete normally + logger.debug("ASYNC_DELETE: " + _get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) + obj.delete() + + +class async_delete: + + """ + Entry point class for async object deletion. + + Usage: + async_del = async_delete() + async_del.delete(instance) + + This class dispatches deletion to module-level Celery tasks via dojo_dispatch_task, + which properly handles user context injection and pghistory context. + """ + + def __init__(self, *args, **kwargs): + # Keep mapping reference for backwards compatibility + self.mapping = ASYNC_DELETE_MAPPING + + def delete(self, obj, **kwargs): + """ + Entry point to delete an object asynchronously. + + Dispatches to async_delete_task via dojo_dispatch_task to ensure proper + handling of async_user and _pgh_context. + """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_delete_task, obj, **kwargs) + + # Keep helper methods for backwards compatibility and potential direct use + @staticmethod + def get_object_name(obj): + return _get_object_name(obj) + + @staticmethod + def chunk_list(model, full_list): chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE") - # Break the list of objects into "chunk_size" lists chunk_list = [full_list[i:i + chunk_size] for i in range(0, len(full_list), chunk_size)] - logger.debug("ASYNC_DELETE: Split " + self.get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size)) + logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size)) return chunk_list - def get_object_name(self, obj): - if obj.__class__.__name__ == "ModelBase": - return obj.__name__ - return obj.__class__.__name__ - @receiver(user_logged_in) def log_user_login(sender, request, user, **kwargs): diff --git a/unittests/test_async_delete.py b/unittests/test_async_delete.py new file mode 100644 index 00000000000..341723e8296 --- /dev/null +++ b/unittests/test_async_delete.py @@ -0,0 +1,313 @@ +""" +Unit tests for async_delete functionality. + +These tests verify that the async_delete class works correctly with dojo_dispatch_task, +which injects async_user and _pgh_context kwargs into task calls. + +The original bug was that @app.task decorated instance methods didn't properly handle +the injected kwargs, causing TypeError: unexpected keyword argument 'async_user'. +""" +import logging + +from crum import impersonate +from django.contrib.auth.models import User +from django.test import override_settings +from django.utils import timezone + +from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type, UserContactInfo +from dojo.utils import async_delete + +from .dojo_test_case import DojoTestCase + +logger = logging.getLogger(__name__) + + +class TestAsyncDelete(DojoTestCase): + + """ + Test async_delete functionality with dojo_dispatch_task kwargs injection. + + These tests use block_execution=True and crum.impersonate to run tasks synchronously, + which allows errors to surface immediately rather than being lost in background workers. + """ + + def setUp(self): + """Set up test user with block_execution=True and disable unneeded features.""" + super().setUp() + + # Create test user with block_execution=True to run tasks synchronously + self.testuser = User.objects.create( + username="test_async_delete_user", + is_staff=True, + is_superuser=True, + ) + UserContactInfo.objects.create(user=self.testuser, block_execution=True) + + # Log in as the test user (for API client) + self.client.force_login(self.testuser) + + # Disable features that might interfere with deletion + self.system_settings(enable_product_grade=False) + self.system_settings(enable_github=False) + self.system_settings(enable_jira=False) + + # Create base test data + self.product_type = Product_Type.objects.create(name="Test Product Type for Async Delete") + self.test_type = Test_Type.objects.get_or_create(name="Manual Test")[0] + + def tearDown(self): + """Clean up any remaining test data.""" + # Clean up in reverse order of dependencies + Finding.objects.filter(test__engagement__product__prod_type=self.product_type).delete() + Test.objects.filter(engagement__product__prod_type=self.product_type).delete() + Engagement.objects.filter(product__prod_type=self.product_type).delete() + Product.objects.filter(prod_type=self.product_type).delete() + self.product_type.delete() + + super().tearDown() + + def _create_product(self, name="Test Product"): + """Helper to create a product for testing.""" + return Product.objects.create( + name=name, + description="Test product for async delete", + prod_type=self.product_type, + ) + + def _create_engagement(self, product, name="Test Engagement"): + """Helper to create an engagement for testing.""" + return Engagement.objects.create( + name=name, + product=product, + target_start=timezone.now(), + target_end=timezone.now(), + ) + + def _create_test(self, engagement, name="Test"): + """Helper to create a test for testing.""" + return Test.objects.create( + engagement=engagement, + test_type=self.test_type, + target_start=timezone.now(), + target_end=timezone.now(), + ) + + def _create_finding(self, test, title="Test Finding"): + """Helper to create a finding for testing.""" + return Finding.objects.create( + test=test, + title=title, + severity="High", + description="Test finding for async delete", + mitigation="Test mitigation", + impact="Test impact", + reporter=self.testuser, + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_simple_object(self): + """ + Test that async_delete works for a simple object (Finding). + + Finding is not in the async_delete mapping, so it falls back to direct delete. + This tests that the module-level task accepts **kwargs properly. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding = self._create_finding(test) + finding_pk = finding.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # This would raise TypeError before the fix: + # TypeError: delete() got an unexpected keyword argument 'async_user' + async_del = async_delete() + async_del.delete(finding) + + # Verify the finding was deleted + self.assertFalse( + Finding.objects.filter(pk=finding_pk).exists(), + "Finding should be deleted", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_test_with_findings(self): + """ + Test that async_delete cascades deletion for Test objects. + + Test is in the async_delete mapping and should cascade delete its findings. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding1 = self._create_finding(test, "Finding 1") + finding2 = self._create_finding(test, "Finding 2") + + test_pk = test.pk + finding1_pk = finding1.pk + finding2_pk = finding2.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the test (should cascade to findings) + async_del = async_delete() + async_del.delete(test) + + # Verify all objects were deleted + self.assertFalse( + Test.objects.filter(pk=test_pk).exists(), + "Test should be deleted", + ) + self.assertFalse( + Finding.objects.filter(pk=finding1_pk).exists(), + "Finding 1 should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk=finding2_pk).exists(), + "Finding 2 should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_engagement_with_tests(self): + """ + Test that async_delete cascades deletion for Engagement objects. + + Engagement is in the async_delete mapping and should cascade delete + its tests and findings. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test1 = self._create_test(engagement, "Test 1") + test2 = self._create_test(engagement, "Test 2") + finding1 = self._create_finding(test1, "Finding in Test 1") + finding2 = self._create_finding(test2, "Finding in Test 2") + + engagement_pk = engagement.pk + test1_pk = test1.pk + test2_pk = test2.pk + finding1_pk = finding1.pk + finding2_pk = finding2.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the engagement (should cascade to tests and findings) + async_del = async_delete() + async_del.delete(engagement) + + # Verify all objects were deleted + self.assertFalse( + Engagement.objects.filter(pk=engagement_pk).exists(), + "Engagement should be deleted", + ) + self.assertFalse( + Test.objects.filter(pk__in=[test1_pk, test2_pk]).exists(), + "Tests should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk__in=[finding1_pk, finding2_pk]).exists(), + "Findings should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_product_with_hierarchy(self): + """ + Test that async_delete cascades deletion for Product objects. + + Product is in the async_delete mapping and should cascade delete + its engagements, tests, and findings. + """ + product = self._create_product() + engagement = self._create_engagement(product) + test = self._create_test(engagement) + finding = self._create_finding(test) + + product_pk = product.pk + engagement_pk = engagement.pk + test_pk = test.pk + finding_pk = finding.pk + + # Use impersonate to set current user context (required for block_execution to work) + with impersonate(self.testuser): + # Delete the product (should cascade to everything) + async_del = async_delete() + async_del.delete(product) + + # Verify all objects were deleted + self.assertFalse( + Product.objects.filter(pk=product_pk).exists(), + "Product should be deleted", + ) + self.assertFalse( + Engagement.objects.filter(pk=engagement_pk).exists(), + "Engagement should be deleted via cascade", + ) + self.assertFalse( + Test.objects.filter(pk=test_pk).exists(), + "Test should be deleted via cascade", + ) + self.assertFalse( + Finding.objects.filter(pk=finding_pk).exists(), + "Finding should be deleted via cascade", + ) + + @override_settings(ASYNC_OBJECT_DELETE=True) + def test_async_delete_accepts_sync_kwarg(self): + """ + Test that async_delete passes through the sync kwarg properly. + + The sync=True kwarg forces synchronous execution for the top-level task. + However, nested task dispatches still need user context to run synchronously, + so we use impersonate here as well. + """ + product = self._create_product() + product_pk = product.pk + + # Use impersonate to ensure nested tasks also run synchronously + with impersonate(self.testuser): + # Explicitly pass sync=True + async_del = async_delete() + async_del.delete(product, sync=True) + + # Verify the product was deleted + self.assertFalse( + Product.objects.filter(pk=product_pk).exists(), + "Product should be deleted with sync=True", + ) + + def test_async_delete_helper_methods(self): + """ + Test that static helper methods on async_delete class still work. + + These are kept for backwards compatibility. + """ + product = self._create_product() + + # Test get_object_name + self.assertEqual( + async_delete.get_object_name(product), + "Product", + "get_object_name should return class name", + ) + + # Test get_object_name with model class + self.assertEqual( + async_delete.get_object_name(Product), + "Product", + "get_object_name should work with model class", + ) + + def test_async_delete_mapping_preserved(self): + """ + Test that the mapping attribute is preserved on async_delete instances. + + This ensures backwards compatibility for code that might access the mapping. + """ + async_del = async_delete() + + # Verify mapping exists and has expected keys + self.assertIsNotNone(async_del.mapping) + self.assertIn("Product", async_del.mapping) + self.assertIn("Product_Type", async_del.mapping) + self.assertIn("Engagement", async_del.mapping) + self.assertIn("Test", async_del.mapping) diff --git a/unittests/test_jira_import_and_pushing_api.py b/unittests/test_jira_import_and_pushing_api.py index e1a8284698e..eb762a74313 100644 --- a/unittests/test_jira_import_and_pushing_api.py +++ b/unittests/test_jira_import_and_pushing_api.py @@ -981,7 +981,7 @@ def test_engagement_epic_mapping_disabled_no_epic_and_push_findings(self): @patch("dojo.jira_link.helper.can_be_pushed_to_jira", return_value=(True, None, None)) @patch("dojo.jira_link.helper.is_push_all_issues", return_value=False) @patch("dojo.jira_link.helper.push_to_jira", return_value=None) - @patch("dojo.notifications.helper.WebhookNotificationManger.send_webhooks_notification") + @patch("dojo.notifications.helper.send_webhooks_notification") def test_bulk_edit_mixed_findings_and_groups_jira_push_bug(self, mock_webhooks, mock_push_to_jira, mock_is_push_all_issues, mock_can_be_pushed): """ Test the bug in bulk edit: when bulk editing findings where some are in groups diff --git a/unittests/test_notifications.py b/unittests/test_notifications.py index fc547f526a1..2e897702fc4 100644 --- a/unittests/test_notifications.py +++ b/unittests/test_notifications.py @@ -680,7 +680,7 @@ def test_webhook_reactivation(self): with self.subTest("active"): wh = Notification_Webhooks.objects.filter(owner=None).first() manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=wh.pk) + manager._webhook_reactivation(endpoint_id=wh.pk) updated_wh = Notification_Webhooks.objects.filter(owner=None).first() self.assertEqual(updated_wh.status, Notification_Webhooks.Status.STATUS_ACTIVE) @@ -699,7 +699,7 @@ def test_webhook_reactivation(self): with self.assertLogs("dojo.notifications.helper", level="DEBUG") as cm: manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=wh.pk) + manager._webhook_reactivation(endpoint_id=wh.pk) updated_wh = Notification_Webhooks.objects.filter(owner=None).first() self.assertEqual(updated_wh.status, Notification_Webhooks.Status.STATUS_ACTIVE_TMP) From d8661ab4c33d48ac6f15e7f7390b6c995bbad6ad Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Wed, 4 Feb 2026 19:02:52 +0100 Subject: [PATCH 2/3] Fix AttributeError in celery task dispatch - Use class reference instead of self for task dispatch (self.method returns bound method without .si() attribute) - Update location_manager.py to use dojo_dispatch_task instead of @dojo_async_task decorator - Convert task methods to static-like functions (no self parameter) --- dojo/importers/endpoint_manager.py | 6 ++--- dojo/importers/location_manager.py | 35 +++++++++++++----------------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index d390067b63c..3f8c3ec817e 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -114,7 +114,7 @@ def chunk_endpoints_and_disperse( endpoints: list[Endpoint], **kwargs: dict, ) -> None: - dojo_dispatch_task(self.add_endpoints_to_unsaved_finding, finding, endpoints, sync=True) + dojo_dispatch_task(EndpointManager.add_endpoints_to_unsaved_finding, finding, endpoints, sync=True) @staticmethod def clean_unsaved_endpoints( @@ -135,7 +135,7 @@ def chunk_endpoints_and_reactivate( endpoint_status_list: list[Endpoint_Status], **kwargs: dict, ) -> None: - dojo_dispatch_task(self.reactivate_endpoint_status, endpoint_status_list, sync=True) + dojo_dispatch_task(EndpointManager.reactivate_endpoint_status, endpoint_status_list, sync=True) def chunk_endpoints_and_mitigate( self, @@ -143,7 +143,7 @@ def chunk_endpoints_and_mitigate( user: Dojo_User, **kwargs: dict, ) -> None: - dojo_dispatch_task(self.mitigate_endpoint_status, endpoint_status_list, user, sync=True) + dojo_dispatch_task(EndpointManager.mitigate_endpoint_status, endpoint_status_list, user, sync=True) def update_endpoint_status( self, diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index c3a12fb5391..82a08c6204d 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -6,7 +6,7 @@ from django.utils import timezone from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import AbstractLocation, LocationFindingReference from dojo.location.status import FindingLocationStatus from dojo.models import ( @@ -24,17 +24,16 @@ # test_notifications.py: Implement Locations class LocationManager: - def get_or_create_location(self, unsaved_location: AbstractLocation) -> AbstractLocation | None: + @staticmethod + def get_or_create_location(unsaved_location: AbstractLocation) -> AbstractLocation | None: if isinstance(unsaved_location, URL): return URL.get_or_create_from_object(unsaved_location) logger.debug(f"IMPORT_SCAN: Unsupported location type: {type(unsaved_location)}") return None - @dojo_async_task - @app.task() + @app.task def add_locations_to_unsaved_finding( - self, - finding: Finding, + finding: Finding, # noqa: N805 locations: list[AbstractLocation], **kwargs: dict, ) -> None: @@ -42,23 +41,21 @@ def add_locations_to_unsaved_finding( locations = list(set(locations)) logger.debug(f"IMPORT_SCAN: Adding {len(locations)} locations to finding: {finding}") - self.clean_unsaved_locations(locations) + LocationManager.clean_unsaved_locations(locations) # LOCATION LOCATION LOCATION # TODO: bulk create the finding/product refs... locations_saved = 0 for unsaved_location in locations: - if saved_location := self.get_or_create_location(unsaved_location): + if saved_location := LocationManager.get_or_create_location(unsaved_location): locations_saved += 1 saved_location.location.associate_with_finding(finding, status=FindingLocationStatus.Active) logger.debug(f"IMPORT_SCAN: {locations_saved} locations imported") - @dojo_async_task - @app.task() + @app.task def mitigate_location_status( - self, - location_refs: QuerySet[LocationFindingReference], + location_refs: QuerySet[LocationFindingReference], # noqa: N805 user: Dojo_User, **kwargs: dict, ) -> None: @@ -69,11 +66,9 @@ def mitigate_location_status( status=FindingLocationStatus.Mitigated, ) - @dojo_async_task - @app.task() + @app.task def reactivate_location_status( - self, - location_refs: QuerySet[LocationFindingReference], + location_refs: QuerySet[LocationFindingReference], # noqa: N805 **kwargs: dict, ) -> None: """Reactivate all given (mitigated) locations refs""" @@ -89,10 +84,10 @@ def chunk_locations_and_disperse( locations: list[AbstractLocation], **kwargs: dict, ) -> None: - self.add_locations_to_unsaved_finding(finding, locations, sync=True) + dojo_dispatch_task(LocationManager.add_locations_to_unsaved_finding, finding, locations, sync=True) + @staticmethod def clean_unsaved_locations( - self, locations: list[AbstractLocation], ) -> None: """ @@ -110,7 +105,7 @@ def chunk_locations_and_reactivate( location_refs: QuerySet[LocationFindingReference], **kwargs: dict, ) -> None: - self.reactivate_location_status(location_refs, sync=True) + dojo_dispatch_task(LocationManager.reactivate_location_status, location_refs, sync=True) def chunk_locations_and_mitigate( self, @@ -118,7 +113,7 @@ def chunk_locations_and_mitigate( user: Dojo_User, **kwargs: dict, ) -> None: - self.mitigate_location_status(location_refs, user, sync=True) + dojo_dispatch_task(LocationManager.mitigate_location_status, location_refs, user, sync=True) def update_location_status( self, From 0e781ce9313a59784134b290fd5907f367a90064 Mon Sep 17 00:00:00 2001 From: Valentijn Scholten Date: Wed, 4 Feb 2026 19:44:10 +0100 Subject: [PATCH 3/3] Fix remaining dojo_dispatch_task call sites - Avoid passing manager/task attributes via instance (use class task objects to ensure .si() is available) - Stop dispatching non-task jira_helper.push_to_jira through dojo_dispatch_task; call it directly and let it dispatch the underlying celery tasks --- dojo/finding_group/views.py | 5 ++--- dojo/importers/base_importer.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/dojo/finding_group/views.py b/dojo/finding_group/views.py index e29c401b80d..451d4dcd720 100644 --- a/dojo/finding_group/views.py +++ b/dojo/finding_group/views.py @@ -16,7 +16,6 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions -from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( FindingFilter, FindingFilterWithoutObjectLookups, @@ -101,7 +100,7 @@ def view_finding_group(request, fgid): elif not finding_group.has_jira_issue: jira_helper.finding_group_link_jira(request, finding_group, jira_issue) elif push_to_jira: - dojo_dispatch_task(jira_helper.push_to_jira, finding_group, sync=True) + jira_helper.push_to_jira(finding_group, sync=True) finding_group.save() return HttpResponseRedirect(reverse("view_test", args=(finding_group.test.id,))) @@ -201,7 +200,7 @@ def push_to_jira(request, fgid): # it may look like success here, but the push_to_jira are swallowing exceptions # but cant't change too much now without having a test suite, so leave as is for now with the addition warning message to check alerts for background errors. - if dojo_dispatch_task(jira_helper.push_to_jira, group, sync=True): + if jira_helper.push_to_jira(group, sync=True): messages.add_message( request, messages.SUCCESS, diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 53df8618443..c79af314f04 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -946,7 +946,7 @@ def mitigate_finding( if settings.V3_FEATURE_LOCATIONS: # Mitigate the location statuses dojo_dispatch_task( - self.location_manager.mitigate_location_status, + LocationManager.mitigate_location_status, finding.locations.all(), self.user, kwuser=self.user, @@ -956,7 +956,7 @@ def mitigate_finding( # TODO: Delete this after the move to Locations # Mitigate the endpoint statuses dojo_dispatch_task( - self.endpoint_manager.mitigate_endpoint_status, + EndpointManager.mitigate_endpoint_status, finding.status_finding.all(), self.user, kwuser=self.user,