diff --git a/dojo/api_v2/views.py b/dojo/api_v2/views.py
index c1a3b12db6f..69d92161dea 100644
--- a/dojo/api_v2/views.py
+++ b/dojo/api_v2/views.py
@@ -46,6 +46,7 @@
 )
 from dojo.api_v2.prefetch.prefetcher import _Prefetcher
 from dojo.authorization.roles_permissions import Permissions
+from dojo.celery_dispatch import dojo_dispatch_task
 from dojo.cred.queries import get_authorized_cred_mappings
 from dojo.endpoint.queries import (
     get_authorized_endpoint_status,
@@ -679,13 +680,13 @@ def update_jira_epic(self, request, pk=None):
         try:
             if engagement.has_jira_issue:
-                jira_helper.update_epic(engagement.id, **request.data)
+                dojo_dispatch_task(jira_helper.update_epic, engagement.id, **request.data)
                 response = Response(
                     {"info": "Jira Epic update query sent"},
                     status=status.HTTP_200_OK,
                 )
             else:
-                jira_helper.add_epic(engagement.id, **request.data)
+                dojo_dispatch_task(jira_helper.add_epic, engagement.id, **request.data)
                 response = Response(
                     {"info": "Jira Epic create query sent"},
                     status=status.HTTP_200_OK,
diff --git a/dojo/celery.py b/dojo/celery.py
index ead4a8813a8..3cf09e1bc2c 100644
--- a/dojo/celery.py
+++ b/dojo/celery.py
@@ -12,16 +12,56 @@
 os.environ.setdefault("DJANGO_SETTINGS_MODULE", "dojo.settings.settings")

-class PgHistoryTask(Task):
+class DojoAsyncTask(Task):
+
+    """
+    Base task class that provides dojo_async_task functionality without using a decorator.
+
+    This class:
+    - Injects user context into task kwargs
+    - Tracks task calls for performance testing
+    - Supports all Celery features (signatures, chords, groups, chains)
+    """
+
+    def apply_async(self, args=None, kwargs=None, **options):
+        """Override apply_async to inject user context and track tasks."""
+        from dojo.decorators import dojo_async_task_counter  # noqa: PLC0415 circular import
+        from dojo.utils import get_current_user  # noqa: PLC0415 circular import
+
+        if kwargs is None:
+            kwargs = {}
+
+        # Inject user context if not already present
+        if "async_user" not in kwargs:
+            kwargs["async_user"] = get_current_user()
+
+        # Control flag used for sync/async decision; never pass into the task itself
+        kwargs.pop("sync", None)
+
+        # Track dispatch
+        dojo_async_task_counter.incr(
+            self.name,
+            args=args,
+            kwargs=kwargs,
+        )
+
+        # Call parent to execute async
+        return super().apply_async(args=args, kwargs=kwargs, **options)
+
+
+class PgHistoryTask(DojoAsyncTask):

     """
     Custom Celery base task that automatically applies pghistory context.

-    When a task is dispatched via dojo_async_task, the current pghistory
-    context is captured and passed in kwargs as "_pgh_context". This base
-    class extracts that context and applies it before running the task,
-    ensuring all database events share the same context as the original
-    request.
+    This class inherits from DojoAsyncTask to provide:
+    - User context injection and task tracking (from DojoAsyncTask)
+    - Automatic pghistory context application (from this class)
+
+    When a task is dispatched via dojo_dispatch_task or dojo_async_task, the current
+    pghistory context is captured and passed in kwargs as "_pgh_context". This base
+    class extracts that context and applies it before running the task, ensuring all
+    database events share the same context as the original request.
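+
+    Minimal usage sketch (the task name and the `product` instance below are hypothetical,
+    for illustration only):
+
+        from dojo.celery import app
+        from dojo.celery_dispatch import dojo_dispatch_task
+
+        @app.task  # app tasks use this base class by default, so the context is re-applied here
+        def regrade_product(product_id, *args, **kwargs):
+            # kwargs may carry "async_user" and "_pgh_context", injected at dispatch time
+            ...
+
+        dojo_dispatch_task(regrade_product, product.id)             # async when allowed
+        dojo_dispatch_task(regrade_product, product.id, sync=True)  # forced foreground run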
""" def __call__(self, *args, **kwargs): diff --git a/dojo/celery_dispatch.py b/dojo/celery_dispatch.py new file mode 100644 index 00000000000..f4ce0e3241b --- /dev/null +++ b/dojo/celery_dispatch.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, cast + +from celery.canvas import Signature + +if TYPE_CHECKING: + from collections.abc import Mapping + + +class _SupportsSi(Protocol): + def si(self, *args: Any, **kwargs: Any) -> Signature: ... + + +class _SupportsApplyAsync(Protocol): + def apply_async(self, args: Any | None = None, kwargs: Any | None = None, **options: Any) -> Any: ... + + +def _inject_async_user(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + result: dict[str, Any] = dict(kwargs or {}) + if "async_user" not in result: + from dojo.utils import get_current_user # noqa: PLC0415 circular import + + result["async_user"] = get_current_user() + return result + + +def _inject_pghistory_context(kwargs: Mapping[str, Any] | None) -> dict[str, Any]: + """Capture and inject pghistory context if available.""" + result: dict[str, Any] = dict(kwargs or {}) + if "_pgh_context" not in result: + from dojo.pghistory_utils import get_serializable_pghistory_context # noqa: PLC0415 circular import + + if pgh_context := get_serializable_pghistory_context(): + result["_pgh_context"] = pgh_context + return result + + +def dojo_create_signature(task_or_sig: _SupportsSi | Signature, *args: Any, **kwargs: Any) -> Signature: + """ + Build a Celery signature with DefectDojo user context and pghistory context injected. + + - If passed a task, returns `task_or_sig.si(*args, **kwargs)`. + - If passed an existing signature, returns a cloned signature with merged kwargs. + """ + injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) + injected.pop("countdown", None) + + if isinstance(task_or_sig, Signature): + merged_kwargs = {**(task_or_sig.kwargs or {}), **injected} + return task_or_sig.clone(kwargs=merged_kwargs) + + return task_or_sig.si(*args, **injected) + + +def dojo_dispatch_task(task_or_sig: _SupportsSi | _SupportsApplyAsync | Signature, *args: Any, **kwargs: Any) -> Any: + """ + Dispatch a task/signature using DefectDojo semantics. + + - Inject `async_user` if missing. + - Capture and inject pghistory context if available. + - Respect `sync=True` (foreground execution) and user `block_execution`. + - Support `countdown=` for async dispatch. + + Returns: + - async: AsyncResult-like return from Celery + - sync: underlying return value of the task + + """ + from dojo.decorators import dojo_async_task_counter, we_want_async # noqa: PLC0415 circular import + + countdown = cast("int", kwargs.pop("countdown", 0)) + injected = _inject_async_user(kwargs) + injected = _inject_pghistory_context(injected) + + sig = dojo_create_signature(task_or_sig if isinstance(task_or_sig, Signature) else cast("_SupportsSi", task_or_sig), *args, **injected) + sig_kwargs = dict(sig.kwargs or {}) + + if we_want_async(*sig.args, func=getattr(sig, "type", None), **sig_kwargs): + # DojoAsyncTask.apply_async tracks async dispatch. Avoid double-counting here. 
+        return sig.apply_async(countdown=countdown)
+
+    # Track foreground execution as a "created task" as well (matches historical dojo_async_task behavior)
+    dojo_async_task_counter.incr(str(sig.task), args=sig.args, kwargs=sig_kwargs)
+
+    sig_kwargs.pop("sync", None)
+    sig = sig.clone(kwargs=sig_kwargs)
+    eager = sig.apply()
+    try:
+        return eager.get(propagate=True)
+    except RuntimeError:
+        # Since we are intentionally running synchronously, we can propagate exceptions directly and enable sync subtasks
+        # if the request desires it. Celery docs explain that this is a rare use case, but we support it _just in case_
+        return eager.get(propagate=True, disable_sync_subtasks=False)
diff --git a/dojo/endpoint/views.py b/dojo/endpoint/views.py
index f66869d35b2..caa48f02757 100644
--- a/dojo/endpoint/views.py
+++ b/dojo/endpoint/views.py
@@ -18,6 +18,7 @@
 from dojo.authorization.authorization import user_has_permission_or_403
 from dojo.authorization.authorization_decorators import user_is_authorized
 from dojo.authorization.roles_permissions import Permissions
+from dojo.celery_dispatch import dojo_dispatch_task
 from dojo.endpoint.queries import get_authorized_endpoints_for_queryset
 from dojo.endpoint.utils import clean_hosts_run, endpoint_meta_import
 from dojo.filters import EndpointFilter, EndpointFilterWithoutObjectLookups
@@ -345,7 +346,7 @@ def endpoint_bulk_update_all(request, pid=None):
             product_calc = list(Product.objects.filter(endpoint__id__in=endpoints_to_update).distinct())
             endpoints.delete()
             for prod in product_calc:
-                calculate_grade(prod.id)
+                dojo_dispatch_task(calculate_grade, prod.id)
         if skipped_endpoint_count > 0:
             add_error_message_to_response(f"Skipped deletion of {skipped_endpoint_count} endpoints because you are not authorized.")
diff --git a/dojo/engagement/services.py b/dojo/engagement/services.py
index cd70af1ea2c..42a7c1c05e4 100644
--- a/dojo/engagement/services.py
+++ b/dojo/engagement/services.py
@@ -5,6 +5,7 @@
 from django.dispatch import receiver

 import dojo.jira_link.helper as jira_helper
+from dojo.celery_dispatch import dojo_dispatch_task
 from dojo.models import Engagement

 logger = logging.getLogger(__name__)
@@ -16,7 +17,7 @@ def close_engagement(eng):
     eng.save()

     if jira_helper.get_jira_project(eng):
-        jira_helper.close_epic(eng.id, push_to_jira=True)
+        dojo_dispatch_task(jira_helper.close_epic, eng.id, push_to_jira=True)


 def reopen_engagement(eng):
diff --git a/dojo/engagement/views.py b/dojo/engagement/views.py
index 4eb5398cb61..fbe0ca6b496 100644
--- a/dojo/engagement/views.py
+++ b/dojo/engagement/views.py
@@ -37,6 +37,7 @@
 from dojo.authorization.authorization import user_has_permission_or_403
 from dojo.authorization.authorization_decorators import user_is_authorized
 from dojo.authorization.roles_permissions import Permissions
+from dojo.celery_dispatch import dojo_dispatch_task
 from dojo.endpoint.utils import save_endpoints_to_add
 from dojo.engagement.queries import get_authorized_engagements
 from dojo.engagement.services import close_engagement, reopen_engagement
@@ -392,7 +393,7 @@ def copy_engagement(request, eid):
         form = DoneForm(request.POST)
         if form.is_valid():
             engagement_copy = engagement.copy()
-            calculate_grade(product.id)
+            dojo_dispatch_task(calculate_grade, product.id)
             messages.add_message(
                 request,
                 messages.SUCCESS,
diff --git a/dojo/finding/deduplication.py b/dojo/finding/deduplication.py
index 54517108f25..51e34015b61 100644
--- a/dojo/finding/deduplication.py
+++ b/dojo/finding/deduplication.py
@@ -8,7 +8,6 @@ from
django.db.models.query_utils import Q from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.models import Finding, System_Settings logger = logging.getLogger(__name__) @@ -45,13 +44,11 @@ def get_finding_models_for_deduplication(finding_ids): ) -@dojo_async_task @app.task def do_dedupe_finding_task(new_finding_id, *args, **kwargs): return do_dedupe_finding_task_internal(Finding.objects.get(id=new_finding_id), *args, **kwargs) -@dojo_async_task @app.task def do_dedupe_batch_task(finding_ids, *args, **kwargs): """ diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 51c65553742..74bf3ec7279 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -16,7 +16,6 @@ import dojo.jira_link.helper as jira_helper import dojo.risk_acceptance.helper as ra_helper from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.endpoint.utils import endpoint_get_or_create, save_endpoints_to_add from dojo.file_uploads.helper import delete_related_files from dojo.finding.deduplication import ( @@ -395,7 +394,6 @@ def add_findings_to_auto_group(name, findings, group_by, *, create_finding_group finding_group.findings.add(*findings) -@dojo_async_task @app.task def post_process_finding_save(finding_id, dedupe_option=True, rules_option=True, product_grading_option=True, # noqa: FBT002 issue_updater_option=True, push_to_jira=False, user=None, *args, **kwargs): # noqa: FBT002 - this is bit hard to fix nice have this universally fixed @@ -440,7 +438,9 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option if product_grading_option: if system_settings.enable_product_grade: - calculate_grade(finding.test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, finding.test.engagement.product.id) else: deduplicationLogger.debug("skipping product grading because it's disabled in system settings") @@ -457,7 +457,6 @@ def post_process_finding_save_internal(finding, dedupe_option=True, rules_option jira_helper.push_to_jira(finding.finding_group) -@dojo_async_task @app.task def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_option=True, product_grading_option=True, issue_updater_option=True, push_to_jira=False, user=None, **kwargs): @@ -500,7 +499,9 @@ def post_process_findings_batch(finding_ids, *args, dedupe_option=True, rules_op tool_issue_updater.async_tool_issue_update(finding) if product_grading_option and system_settings.enable_product_grade: - calculate_grade(findings[0].test.engagement.product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, findings[0].test.engagement.product.id) if push_to_jira: for finding in findings: diff --git a/dojo/finding/views.py b/dojo/finding/views.py index 3269f92902a..20f37713737 100644 --- a/dojo/finding/views.py +++ b/dojo/finding/views.py @@ -38,6 +38,7 @@ user_is_authorized, ) from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.filters import ( AcceptedFindingFilter, AcceptedFindingFilterWithoutObjectLookups, @@ -1099,7 +1100,7 @@ def process_form(self, request: HttpRequest, finding: Finding, context: dict): product = finding.test.engagement.product finding.delete() # Update the grade of the product async - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) # Add a message to the 
request that the finding was successfully deleted messages.add_message( request, @@ -1374,7 +1375,7 @@ def copy_finding(request, fid): test = form.cleaned_data.get("test") product = finding.test.engagement.product finding_copy = finding.copy(test=test) - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index 3e00a216cbd..c79af314f04 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -3,7 +3,6 @@ import time from collections.abc import Iterable -from celery import chord, group from django.conf import settings from django.core.exceptions import ValidationError from django.core.files.base import ContentFile @@ -14,7 +13,7 @@ import dojo.finding.helper as finding_helper import dojo.risk_acceptance.helper as ra_helper -from dojo import utils +from dojo.celery_dispatch import dojo_dispatch_task from dojo.importers.endpoint_manager import EndpointManager from dojo.importers.location_manager import LocationManager from dojo.importers.options import ImporterOptions @@ -31,7 +30,6 @@ Endpoint, FileUpload, Finding, - System_Settings, Test, Test_Import, Test_Import_Finding_Action, @@ -676,47 +674,6 @@ def update_test_type_from_internal_test(self, internal_test: ParserTest) -> None self.test.test_type.dynamic_tool = dynamic_tool self.test.test_type.save() - def maybe_launch_post_processing_chord( - self, - post_processing_task_signatures, - current_batch_number: int, - max_batch_size: int, - * - is_final_batch: bool, - ) -> tuple[list, int, bool]: - """ - Helper to optionally launch a chord of post-processing tasks with a calculate-grade callback - when async is desired. Uses exponential batch sizing up to the configured max batch size. - - Returns a tuple of (post_processing_task_signatures, current_batch_number, launched) - where launched indicates whether a chord/group was dispatched and signatures were reset. 
- """ - launched = False - if not post_processing_task_signatures: - return post_processing_task_signatures, current_batch_number, launched - - current_batch_size = min(2 ** current_batch_number, max_batch_size) - batch_full = len(post_processing_task_signatures) >= current_batch_size - - if batch_full or is_final_batch: - product = self.test.engagement.product - system_settings = System_Settings.objects.get() - if system_settings.enable_product_grade: - calculate_grade_signature = utils.calculate_grade.si(product.id) - chord(post_processing_task_signatures)(calculate_grade_signature) - else: - group(post_processing_task_signatures).apply_async() - - logger.debug( - f"Launched chord with {len(post_processing_task_signatures)} tasks (batch #{current_batch_number}, size: {len(post_processing_task_signatures)})", - ) - post_processing_task_signatures = [] - if not is_final_batch: - current_batch_number += 1 - launched = True - - return post_processing_task_signatures, current_batch_number, launched - def verify_tool_configuration_from_test(self): """ Verify that the Tool_Configuration supplied along with the @@ -988,11 +945,23 @@ def mitigate_finding( ra_helper.risk_unaccept(self.user, finding, perform_save=False, post_comments=False) if settings.V3_FEATURE_LOCATIONS: # Mitigate the location statuses - self.location_manager.mitigate_location_status(finding.locations.all(), self.user, kwuser=self.user, sync=True) + dojo_dispatch_task( + LocationManager.mitigate_location_status, + finding.locations.all(), + self.user, + kwuser=self.user, + sync=True, + ) else: # TODO: Delete this after the move to Locations # Mitigate the endpoint statuses - self.endpoint_manager.mitigate_endpoint_status(finding.status_finding.all(), self.user, kwuser=self.user, sync=True) + dojo_dispatch_task( + EndpointManager.mitigate_endpoint_status, + finding.status_finding.all(), + self.user, + kwuser=self.user, + sync=True, + ) # to avoid pushing a finding group multiple times, we push those outside of the loop if finding_groups_enabled and finding.finding_group: # don't try to dedupe findings that we are closing diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 6fc2beff074..a35eaa27496 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -7,6 +7,7 @@ from django.urls import reverse import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser from dojo.importers.options import ImporterOptions @@ -265,7 +266,8 @@ def process_findings( batch_finding_ids.clear() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", push_to_jira, len(finding_ids_batch), is_final_finding) - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, dedupe_option=True, rules_option=True, diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 863a2cf0212..8d8bcc2d13e 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -7,6 +7,7 @@ import dojo.finding.helper as finding_helper import dojo.jira_link.helper as jira_helper +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.deduplication import ( find_candidates_for_deduplication_hash, find_candidates_for_deduplication_uid_or_hash, @@ -432,7 +433,8 @@ def 
process_findings( if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() - finding_helper.post_process_findings_batch( + dojo_dispatch_task( + finding_helper.post_process_findings_batch, finding_ids_batch, dedupe_option=True, rules_option=True, diff --git a/dojo/importers/endpoint_manager.py b/dojo/importers/endpoint_manager.py index 8817ff71bdb..3f8c3ec817e 100644 --- a/dojo/importers/endpoint_manager.py +++ b/dojo/importers/endpoint_manager.py @@ -5,7 +5,7 @@ from django.utils import timezone from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.endpoint.utils import endpoint_get_or_create from dojo.models import ( Dojo_User, @@ -19,17 +19,15 @@ # TODO: Delete this after the move to Locations class EndpointManager: - @dojo_async_task - @app.task() + @app.task def add_endpoints_to_unsaved_finding( - self, - finding: Finding, + finding: Finding, # noqa: N805 endpoints: list[Endpoint], **kwargs: dict, ) -> None: """Creates Endpoint objects for a single finding and creates the link via the endpoint status""" logger.debug(f"IMPORT_SCAN: Adding {len(endpoints)} endpoints to finding: {finding}") - self.clean_unsaved_endpoints(endpoints) + EndpointManager.clean_unsaved_endpoints(endpoints) for endpoint in endpoints: ep = None eps = [] @@ -42,7 +40,8 @@ def add_endpoints_to_unsaved_finding( path=endpoint.path, query=endpoint.query, fragment=endpoint.fragment, - product=finding.test.engagement.product) + product=finding.test.engagement.product, + ) eps.append(ep) except (MultipleObjectsReturned): msg = ( @@ -59,11 +58,9 @@ def add_endpoints_to_unsaved_finding( logger.debug(f"IMPORT_SCAN: {len(endpoints)} endpoints imported") - @dojo_async_task - @app.task() + @app.task def mitigate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 user: Dojo_User, **kwargs: dict, ) -> None: @@ -86,11 +83,9 @@ def mitigate_endpoint_status( batch_size=1000, ) - @dojo_async_task - @app.task() + @app.task def reactivate_endpoint_status( - self, - endpoint_status_list: list[Endpoint_Status], + endpoint_status_list: list[Endpoint_Status], # noqa: N805 **kwargs: dict, ) -> None: """Reactivate all endpoint status objects that are supplied""" @@ -119,10 +114,10 @@ def chunk_endpoints_and_disperse( endpoints: list[Endpoint], **kwargs: dict, ) -> None: - self.add_endpoints_to_unsaved_finding(finding, endpoints, sync=True) + dojo_dispatch_task(EndpointManager.add_endpoints_to_unsaved_finding, finding, endpoints, sync=True) + @staticmethod def clean_unsaved_endpoints( - self, endpoints: list[Endpoint], ) -> None: """ @@ -140,7 +135,7 @@ def chunk_endpoints_and_reactivate( endpoint_status_list: list[Endpoint_Status], **kwargs: dict, ) -> None: - self.reactivate_endpoint_status(endpoint_status_list, sync=True) + dojo_dispatch_task(EndpointManager.reactivate_endpoint_status, endpoint_status_list, sync=True) def chunk_endpoints_and_mitigate( self, @@ -148,7 +143,7 @@ def chunk_endpoints_and_mitigate( user: Dojo_User, **kwargs: dict, ) -> None: - self.mitigate_endpoint_status(endpoint_status_list, user, sync=True) + dojo_dispatch_task(EndpointManager.mitigate_endpoint_status, endpoint_status_list, user, sync=True) def update_endpoint_status( self, diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index c3a12fb5391..82a08c6204d 100644 --- 
a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -6,7 +6,7 @@ from django.utils import timezone from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import AbstractLocation, LocationFindingReference from dojo.location.status import FindingLocationStatus from dojo.models import ( @@ -24,17 +24,16 @@ # test_notifications.py: Implement Locations class LocationManager: - def get_or_create_location(self, unsaved_location: AbstractLocation) -> AbstractLocation | None: + @staticmethod + def get_or_create_location(unsaved_location: AbstractLocation) -> AbstractLocation | None: if isinstance(unsaved_location, URL): return URL.get_or_create_from_object(unsaved_location) logger.debug(f"IMPORT_SCAN: Unsupported location type: {type(unsaved_location)}") return None - @dojo_async_task - @app.task() + @app.task def add_locations_to_unsaved_finding( - self, - finding: Finding, + finding: Finding, # noqa: N805 locations: list[AbstractLocation], **kwargs: dict, ) -> None: @@ -42,23 +41,21 @@ def add_locations_to_unsaved_finding( locations = list(set(locations)) logger.debug(f"IMPORT_SCAN: Adding {len(locations)} locations to finding: {finding}") - self.clean_unsaved_locations(locations) + LocationManager.clean_unsaved_locations(locations) # LOCATION LOCATION LOCATION # TODO: bulk create the finding/product refs... locations_saved = 0 for unsaved_location in locations: - if saved_location := self.get_or_create_location(unsaved_location): + if saved_location := LocationManager.get_or_create_location(unsaved_location): locations_saved += 1 saved_location.location.associate_with_finding(finding, status=FindingLocationStatus.Active) logger.debug(f"IMPORT_SCAN: {locations_saved} locations imported") - @dojo_async_task - @app.task() + @app.task def mitigate_location_status( - self, - location_refs: QuerySet[LocationFindingReference], + location_refs: QuerySet[LocationFindingReference], # noqa: N805 user: Dojo_User, **kwargs: dict, ) -> None: @@ -69,11 +66,9 @@ def mitigate_location_status( status=FindingLocationStatus.Mitigated, ) - @dojo_async_task - @app.task() + @app.task def reactivate_location_status( - self, - location_refs: QuerySet[LocationFindingReference], + location_refs: QuerySet[LocationFindingReference], # noqa: N805 **kwargs: dict, ) -> None: """Reactivate all given (mitigated) locations refs""" @@ -89,10 +84,10 @@ def chunk_locations_and_disperse( locations: list[AbstractLocation], **kwargs: dict, ) -> None: - self.add_locations_to_unsaved_finding(finding, locations, sync=True) + dojo_dispatch_task(LocationManager.add_locations_to_unsaved_finding, finding, locations, sync=True) + @staticmethod def clean_unsaved_locations( - self, locations: list[AbstractLocation], ) -> None: """ @@ -110,7 +105,7 @@ def chunk_locations_and_reactivate( location_refs: QuerySet[LocationFindingReference], **kwargs: dict, ) -> None: - self.reactivate_location_status(location_refs, sync=True) + dojo_dispatch_task(LocationManager.reactivate_location_status, location_refs, sync=True) def chunk_locations_and_mitigate( self, @@ -118,7 +113,7 @@ def chunk_locations_and_mitigate( user: Dojo_User, **kwargs: dict, ) -> None: - self.mitigate_location_status(location_refs, user, sync=True) + dojo_dispatch_task(LocationManager.mitigate_location_status, location_refs, user, sync=True) def update_location_status( self, diff --git a/dojo/jira_link/helper.py b/dojo/jira_link/helper.py index 
6f7774c7cc3..7a9ccaa3e96 100644 --- a/dojo/jira_link/helper.py +++ b/dojo/jira_link/helper.py @@ -18,7 +18,7 @@ from requests.auth import HTTPBasicAuth from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.forms import JIRAEngagementForm, JIRAProjectForm from dojo.models import ( Engagement, @@ -763,20 +763,19 @@ def push_to_jira(obj, *args, **kwargs): if isinstance(obj, Finding): if obj.has_finding_group: logger.debug("pushing finding group for %s to JIRA", obj) - return push_finding_group_to_jira(obj.finding_group.id, *args, **kwargs) - return push_finding_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.finding_group.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Finding_Group): - return push_finding_group_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_finding_group_to_jira, obj.id, *args, **kwargs) if isinstance(obj, Engagement): - return push_engagement_to_jira(obj.id, *args, **kwargs) + return dojo_dispatch_task(push_engagement_to_jira, obj.id, *args, **kwargs) logger.error("unsupported object passed to push_to_jira: %s %i %s", obj.__name__, obj.id, obj) return None # we need thre separate celery tasks due to the decorators we're using to map to/from ids -@dojo_async_task @app.task def push_finding_to_jira(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -789,7 +788,6 @@ def push_finding_to_jira(finding_id, *args, **kwargs): return add_jira_issue(finding, *args, **kwargs) -@dojo_async_task @app.task def push_finding_group_to_jira(finding_group_id, *args, **kwargs): finding_group = get_object_or_none(Finding_Group, id=finding_group_id) @@ -806,7 +804,6 @@ def push_finding_group_to_jira(finding_group_id, *args, **kwargs): return add_jira_issue(finding_group, *args, **kwargs) -@dojo_async_task @app.task def push_engagement_to_jira(engagement_id, *args, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -815,8 +812,8 @@ def push_engagement_to_jira(engagement_id, *args, **kwargs): return None if engagement.has_jira_issue: - return update_epic(engagement.id, *args, **kwargs) - return add_epic(engagement.id, *args, **kwargs) + return dojo_dispatch_task(update_epic, engagement.id, *args, **kwargs) + return dojo_dispatch_task(add_epic, engagement.id, *args, **kwargs) def add_issues_to_epic(jira, obj, epic_id, issue_keys, *, ignore_epics=True): @@ -1396,7 +1393,6 @@ def jira_check_attachment(issue, source_file_name): return file_exists -@dojo_async_task @app.task def close_epic(engagement_id, push_to_jira, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1445,7 +1441,6 @@ def close_epic(engagement_id, push_to_jira, **kwargs): return False -@dojo_async_task @app.task def update_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1492,7 +1487,6 @@ def update_epic(engagement_id, **kwargs): return False -@dojo_async_task @app.task def add_epic(engagement_id, **kwargs): engagement = get_object_or_none(Engagement, id=engagement_id) @@ -1601,10 +1595,9 @@ def add_comment(obj, note, *, force_push=False, **kwargs): return False # Call the internal task with IDs (runs synchronously within this task) - return add_comment_internal(jira_issue.id, note.id, force_push=force_push, **kwargs) + return dojo_dispatch_task(add_comment_internal, jira_issue.id, note.id, 
force_push=force_push, **kwargs) -@dojo_async_task @app.task def add_comment_internal(jira_issue_id, note_id, *, force_push=False, **kwargs): """Internal Celery task that adds a comment to a JIRA issue.""" diff --git a/dojo/management/commands/dedupe.py b/dojo/management/commands/dedupe.py index a4cbee519a6..3d17f4a7fe9 100644 --- a/dojo/management/commands/dedupe.py +++ b/dojo/management/commands/dedupe.py @@ -131,14 +131,25 @@ def _run_dedupe(self, *, restrict_to_parsers, hash_code_only, dedupe_only, dedup mass_model_updater(Finding, findings, do_dedupe_finding_task_internal, fields=None, order="desc", page_size=100, log_prefix="deduplicating ") else: # async tasks only need the id - mass_model_updater(Finding, findings.only("id"), lambda f: do_dedupe_finding_task(f.id), fields=None, order="desc", log_prefix="deduplicating ") + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + mass_model_updater( + Finding, + findings.only("id"), + lambda f: dojo_dispatch_task(do_dedupe_finding_task, f.id), + fields=None, + order="desc", + log_prefix="deduplicating ", + ) if dedupe_sync: # update the grading (if enabled) and only useful in sync mode # in async mode the background task that grades products every hour will pick it up logger.debug("Updating grades for products...") for product in Product.objects.all(): - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) logger.info("######## Done deduplicating (%s) ########", ("foreground" if dedupe_sync else "tasks submitted to celery")) else: @@ -185,7 +196,9 @@ def _dedupe_batch_mode(self, findings_queryset, *, dedupe_sync: bool = True): else: # Asynchronous: submit task with finding IDs logger.debug(f"Submitting async batch task for {len(batch_finding_ids)} findings for test {test_id}") - do_dedupe_batch_task(batch_finding_ids) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(do_dedupe_batch_task, batch_finding_ids) total_processed += len(batch_finding_ids) batch_finding_ids = [] diff --git a/dojo/models.py b/dojo/models.py index 2281a2bfac8..5f470c83d93 100644 --- a/dojo/models.py +++ b/dojo/models.py @@ -1092,7 +1092,9 @@ def save(self, *args, **kwargs): super(Product, product).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(self, products, severities=severities) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, self, products, severities=severities) def clean(self): sla_days = [self.critical, self.high, self.medium, self.low] @@ -1252,7 +1254,9 @@ def save(self, *args, **kwargs): super(SLA_Configuration, sla_config).save() # launch the async task to update all finding sla expiration dates from dojo.sla_config.helpers import async_update_sla_expiration_dates_sla_config_sync # noqa: I001, PLC0415 circular import - async_update_sla_expiration_dates_sla_config_sync(sla_config, Product.objects.filter(id=self.id)) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(async_update_sla_expiration_dates_sla_config_sync, sla_config, Product.objects.filter(id=self.id)) def 
get_absolute_url(self): return reverse("view_product", args=[str(self.id)]) diff --git a/dojo/notifications/helper.py b/dojo/notifications/helper.py index c4458daec01..7318bae91af 100644 --- a/dojo/notifications/helper.py +++ b/dojo/notifications/helper.py @@ -18,7 +18,8 @@ from dojo import __version__ as dd_version from dojo.authorization.roles_permissions import Permissions from dojo.celery import app -from dojo.decorators import dojo_async_task, we_want_async +from dojo.celery_dispatch import dojo_dispatch_task +from dojo.decorators import we_want_async from dojo.labels import get_labels from dojo.models import ( Alerts, @@ -45,6 +46,26 @@ labels = get_labels() +def get_manager_class_instance(): + default_manager = NotificationManager + notification_manager_class = default_manager + if isinstance( + ( + notification_manager := getattr( + settings, + "NOTIFICATION_MANAGER", + default_manager, + ) + ), + str, + ): + with suppress(ModuleNotFoundError): + module_name, _separator, class_name = notification_manager.rpartition(".") + module = importlib.import_module(module_name) + notification_manager_class = getattr(module, class_name) + return notification_manager_class() + + def create_notification( event: str | None = None, title: str | None = None, @@ -62,23 +83,7 @@ def create_notification( **kwargs: dict, ) -> None: """Create an instance of a NotificationManager and dispatch the notification.""" - default_manager = NotificationManager - notification_manager_class = default_manager - if isinstance( - ( - notification_manager := getattr( - settings, - "NOTIFICATION_MANAGER", - default_manager, - ) - ), - str, - ): - with suppress(ModuleNotFoundError): - module_name, _separator, class_name = notification_manager.rpartition(".") - module = importlib.import_module(module_name) - notification_manager_class = getattr(module, class_name) - notification_manager_class().create_notification( + get_manager_class_instance().create_notification( event=event, title=title, finding=finding, @@ -199,8 +204,6 @@ class SlackNotificationManger(NotificationManagerHelpers): """Manger for slack notifications and their helpers.""" - @dojo_async_task - @app.task def send_slack_notification( self, event: str, @@ -317,8 +320,6 @@ class MSTeamsNotificationManger(NotificationManagerHelpers): """Manger for Microsoft Teams notifications and their helpers.""" - @dojo_async_task - @app.task def send_msteams_notification( self, event: str, @@ -368,8 +369,6 @@ class EmailNotificationManger(NotificationManagerHelpers): """Manger for email notifications and their helpers.""" - @dojo_async_task - @app.task def send_mail_notification( self, event: str, @@ -420,8 +419,6 @@ class WebhookNotificationManger(NotificationManagerHelpers): ERROR_PERMANENT = "permanent" ERROR_TEMPORARY = "temporary" - @dojo_async_task - @app.task def send_webhooks_notification( self, event: str, @@ -480,11 +477,7 @@ def send_webhooks_notification( endpoint.first_error = now endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_TMP # In case of failure within one day, endpoint can be deactivated temporally only for one minute - self._webhook_reactivation.apply_async( - args=[self], - kwargs={"endpoint_id": endpoint.pk}, - countdown=60, - ) + webhook_reactivation.apply_async(kwargs={"endpoint_id": endpoint.pk}, countdown=60) # There is no reason to keep endpoint active if it is returning 4xx errors else: endpoint.status = Notification_Webhooks.Status.STATUS_INACTIVE_PERMANENT @@ -559,7 +552,6 @@ def _test_webhooks_notification(self, endpoint: 
Notification_Webhooks) -> None: # in "send_webhooks_notification", we are doing deeper analysis, why it failed # for now, "raise_for_status" should be enough - @app.task(ignore_result=True) def _webhook_reactivation(self, endpoint_id: int, **_kwargs: dict): endpoint = Notification_Webhooks.objects.get(pk=endpoint_id) # User already changed status of endpoint @@ -832,9 +824,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Slack Notification") - self._get_manager_instance("slack").send_slack_notification( + dojo_dispatch_task( + send_slack_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -844,9 +837,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending MSTeams Notification") - self._get_manager_instance("msteams").send_msteams_notification( + dojo_dispatch_task( + send_msteams_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -856,9 +850,10 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Mail Notification") - self._get_manager_instance("mail").send_mail_notification( + dojo_dispatch_task( + send_mail_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) @@ -868,13 +863,43 @@ def _process_notifications( notifications.other, ): logger.debug("Sending Webhooks Notification") - self._get_manager_instance("webhooks").send_webhooks_notification( + dojo_dispatch_task( + send_webhooks_notification, event, - user=notifications.user, + user_id=getattr(notifications.user, "id", None), **kwargs, ) +@app.task +def send_slack_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("slack").send_slack_notification(event, user=user, **kwargs) + + +@app.task +def send_msteams_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("msteams").send_msteams_notification(event, user=user, **kwargs) + + +@app.task +def send_mail_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("mail").send_mail_notification(event, user=user, **kwargs) + + +@app.task +def send_webhooks_notification(event: str, user_id: int | None = None, **kwargs: dict) -> None: + user = Dojo_User.objects.get(pk=user_id) if user_id else None + get_manager_class_instance()._get_manager_instance("webhooks").send_webhooks_notification(event, user=user, **kwargs) + + +@app.task(ignore_result=True) +def webhook_reactivation(endpoint_id: int, **_kwargs: dict) -> None: + get_manager_class_instance()._get_manager_instance("webhooks")._webhook_reactivation(endpoint_id=endpoint_id) + + @app.task(ignore_result=True) def webhook_status_cleanup(*_args: list, **_kwargs: dict): # If some endpoint was affected by some outage (5xx, 429, Timeout) but it was clean during last 24 hours, @@ -902,4 +927,4 @@ def webhook_status_cleanup(*_args: list, **_kwargs: dict): ) for endpoint in broken_endpoints: manager = WebhookNotificationManger() - manager._webhook_reactivation(manager, endpoint_id=endpoint.pk) + manager._webhook_reactivation(endpoint_id=endpoint.pk) diff --git 
a/dojo/product/helpers.py b/dojo/product/helpers.py index 8247cad4fa8..cdb0750f317 100644 --- a/dojo/product/helpers.py +++ b/dojo/product/helpers.py @@ -5,14 +5,12 @@ from django.db.models import Q from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.location.models import Location from dojo.models import Endpoint, Engagement, Finding, Product, Test logger = logging.getLogger(__name__) -@dojo_async_task @app.task def propagate_tags_on_product(product_id, *args, **kwargs): with contextlib.suppress(Product.DoesNotExist): diff --git a/dojo/sla_config/helpers.py b/dojo/sla_config/helpers.py index da5899a85b0..045456f38d7 100644 --- a/dojo/sla_config/helpers.py +++ b/dojo/sla_config/helpers.py @@ -1,14 +1,12 @@ import logging from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.models import Finding, Product, SLA_Configuration, System_Settings from dojo.utils import get_custom_method, mass_model_updater logger = logging.getLogger(__name__) -@dojo_async_task @app.task def async_update_sla_expiration_dates_sla_config_sync(sla_config: SLA_Configuration, products: list[Product], *args, severities: list[str] | None = None, **kwargs): if method := get_custom_method("FINDING_SLA_EXPIRATION_CALCULATION_METHOD"): diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 58682421d0b..0fea7ae8ad5 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -4,6 +4,7 @@ from django.db.models import signals from django.dispatch import receiver +from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import Location, LocationFindingReference, LocationProductReference from dojo.models import Endpoint, Engagement, Finding, Product, Test from dojo.product import helpers as async_product_funcs @@ -20,7 +21,7 @@ def product_tags_post_add_remove(sender, instance, action, **kwargs): running_async_process = instance.running_async_process # Check if the async process is already running to avoid calling it a second time if not running_async_process and inherit_product_tags(instance): - async_product_funcs.propagate_tags_on_product(instance.id, countdown=5) + dojo_dispatch_task(async_product_funcs.propagate_tags_on_product, instance.id, countdown=5) instance.running_async_process = True diff --git a/dojo/tasks.py b/dojo/tasks.py index b934abc9f02..1bbe104783b 100644 --- a/dojo/tasks.py +++ b/dojo/tasks.py @@ -2,6 +2,7 @@ from datetime import timedelta import pghistory +from celery import Task from celery.utils.log import get_task_logger from django.apps import apps from django.conf import settings @@ -12,7 +13,7 @@ from dojo.auditlog import run_flush_auditlog from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.helper import fix_loop_duplicates from dojo.location.models import Location from dojo.management.commands.jira_status_reconciliation import jira_status_reconciliation @@ -31,7 +32,7 @@ def log_generic_alert(source, title, description): @app.task(bind=True) -def add_alerts(self, runinterval): +def add_alerts(self, runinterval, *args, **kwargs): now = timezone.now() upcoming_engagements = Engagement.objects.filter(target_start__gt=now + timedelta(days=3), target_start__lt=now + timedelta(days=3) + runinterval).order_by("target_start") @@ -73,7 +74,7 @@ def add_alerts(self, runinterval): if system_settings.enable_product_grade: products = Product.objects.all() for product in products: - calculate_grade(product.id) + 
dojo_dispatch_task(calculate_grade, product.id) @app.task(bind=True) @@ -170,11 +171,18 @@ def _async_dupe_delete_impl(): if system_settings.enable_product_grade: logger.info("performing batch product grading for %s products", len(affected_products)) for product in affected_products: - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) -@app.task(ignore_result=False) +@app.task(ignore_result=False, base=Task) def celery_status(): + """ + Simple health check task to verify Celery is running. + + Uses base Task class (not PgHistoryTask) since it doesn't need: + - User context tracking + - Pghistory context (no database modifications) + """ return True @@ -242,7 +250,6 @@ def clear_sessions(*args, **kwargs): call_command("clearsessions") -@dojo_async_task @app.task def update_watson_search_index_for_model(model_name, pk_list, *args, **kwargs): """ diff --git a/dojo/templatetags/display_tags.py b/dojo/templatetags/display_tags.py index 0fd6b604f06..18c4b7f6f4a 100644 --- a/dojo/templatetags/display_tags.py +++ b/dojo/templatetags/display_tags.py @@ -363,7 +363,9 @@ def product_grade(product): if system_settings.enable_product_grade and product: prod_numeric_grade = product.prod_numeric_grade if not prod_numeric_grade or prod_numeric_grade is None: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) if prod_numeric_grade: if prod_numeric_grade >= system_settings.product_grade_a: grade = "A" diff --git a/dojo/test/views.py b/dojo/test/views.py index 96f7c54e470..c8b52bb6f14 100644 --- a/dojo/test/views.py +++ b/dojo/test/views.py @@ -27,6 +27,7 @@ from dojo.authorization.authorization import user_has_permission_or_403 from dojo.authorization.authorization_decorators import user_is_authorized from dojo.authorization.roles_permissions import Permissions +from dojo.celery_dispatch import dojo_dispatch_task from dojo.engagement.queries import get_authorized_engagements from dojo.filters import FindingFilter, FindingFilterWithoutObjectLookups, TemplateFindingFilter, TestImportFilter from dojo.finding.queries import prefetch_for_findings @@ -345,7 +346,7 @@ def copy_test(request, tid): engagement = form.cleaned_data.get("engagement") product = test.engagement.product test_copy = test.copy(engagement=engagement) - calculate_grade(product.id) + dojo_dispatch_task(calculate_grade, product.id) messages.add_message( request, messages.SUCCESS, diff --git a/dojo/tools/tool_issue_updater.py b/dojo/tools/tool_issue_updater.py index 854fb989113..8211e166eed 100644 --- a/dojo/tools/tool_issue_updater.py +++ b/dojo/tools/tool_issue_updater.py @@ -3,7 +3,7 @@ import pghistory from dojo.celery import app -from dojo.decorators import dojo_async_task +from dojo.celery_dispatch import dojo_dispatch_task from dojo.models import Finding from dojo.tools.api_sonarqube.parser import SCAN_SONARQUBE_API from dojo.tools.api_sonarqube.updater import SonarQubeApiUpdater @@ -15,7 +15,7 @@ def async_tool_issue_update(finding, *args, **kwargs): if is_tool_issue_updater_needed(finding): - tool_issue_updater(finding.id) + dojo_dispatch_task(tool_issue_updater, finding.id) def is_tool_issue_updater_needed(finding, *args, **kwargs): @@ -23,7 +23,6 @@ def is_tool_issue_updater_needed(finding, *args, **kwargs): return test_type.name == SCAN_SONARQUBE_API -@dojo_async_task @app.task def tool_issue_updater(finding_id, *args, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -37,7 
+36,6 @@ def tool_issue_updater(finding_id, *args, **kwargs): SonarQubeApiUpdater().update_sonarqube_finding(finding) -@dojo_async_task @app.task def update_findings_from_source_issues(**kwargs): # Wrap with pghistory context for audit trail diff --git a/dojo/utils.py b/dojo/utils.py index 980cd107659..ba1b5ed0d7c 100644 --- a/dojo/utils.py +++ b/dojo/utils.py @@ -47,7 +47,6 @@ from dojo.authorization.roles_permissions import Permissions from dojo.celery import app -from dojo.decorators import dojo_async_task from dojo.finding.queries import get_authorized_findings from dojo.github import ( add_external_issue_github, @@ -1057,7 +1056,6 @@ def handle_uploaded_selenium(f, cred): cred.save() -@dojo_async_task @app.task def add_external_issue(finding_id, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1073,7 +1071,6 @@ def add_external_issue(finding_id, external_issue_provider, **kwargs): add_external_issue_github(finding, prod, eng) -@dojo_async_task @app.task def update_external_issue(finding_id, old_status, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1088,7 +1085,6 @@ def update_external_issue(finding_id, old_status, external_issue_provider, **kwa update_external_issue_github(finding, prod, eng) -@dojo_async_task @app.task def close_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1103,7 +1099,6 @@ def close_external_issue(finding_id, note, external_issue_provider, **kwargs): close_external_issue_github(finding, note, prod, eng) -@dojo_async_task @app.task def reopen_external_issue(finding_id, note, external_issue_provider, **kwargs): finding = get_object_or_none(Finding, id=finding_id) @@ -1259,7 +1254,6 @@ def grade_product(crit, high, med, low): return max(health, 5) -@dojo_async_task @app.task def calculate_grade(product_id, *args, **kwargs): product = get_object_or_none(Product, id=product_id) @@ -1317,7 +1311,9 @@ def calculate_grade_internal(product, *args, **kwargs): def perform_product_grading(product): system_settings = System_Settings.objects.get() if system_settings.enable_product_grade: - calculate_grade(product.id) + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + dojo_dispatch_task(calculate_grade, product.id) def get_celery_worker_status(): @@ -2037,130 +2033,187 @@ def is_finding_groups_enabled(): return get_system_setting("enable_finding_groups") -class async_delete: - def __init__(self, *args, **kwargs): - self.mapping = { - "Product_Type": [ - (Endpoint, "product__prod_type__id"), - (Finding, "test__engagement__product__prod_type__id"), - (Test, "engagement__product__prod_type__id"), - (Engagement, "product__prod_type__id"), - (Product, "prod_type__id")], - "Product": [ - (Endpoint, "product__id"), - (Finding, "test__engagement__product__id"), - (Test, "engagement__product__id"), - (Engagement, "product__id")], - "Engagement": [ - (Finding, "test__engagement__id"), - (Test, "engagement__id")], - "Test": [(Finding, "test__id")], - } +# Mapping of object types to their related models for cascading deletes +ASYNC_DELETE_MAPPING = { + "Product_Type": [ + (Endpoint, "product__prod_type__id"), + (Finding, "test__engagement__product__prod_type__id"), + (Test, "engagement__product__prod_type__id"), + (Engagement, "product__prod_type__id"), + (Product, "prod_type__id")], + "Product": [ + (Endpoint, "product__id"), + (Finding, "test__engagement__product__id"), + (Test, 
"engagement__product__id"), + (Engagement, "product__id")], + "Engagement": [ + (Finding, "test__engagement__id"), + (Test, "engagement__id")], + "Test": [(Finding, "test__id")], +} + + +def _get_object_name(obj): + """Get the class name of an object or model class.""" + if obj.__class__.__name__ == "ModelBase": + return obj.__name__ + return obj.__class__.__name__ - @dojo_async_task - @app.task - def delete_chunk(self, objects, **kwargs): - # Now delete all objects with retry for deadlocks - max_retries = 3 - for obj in objects: - retry_count = 0 - while retry_count < max_retries: - try: - obj.delete() - break # Success, exit retry loop - except OperationalError as e: - error_msg = str(e) - if "deadlock detected" in error_msg.lower(): - retry_count += 1 - if retry_count < max_retries: - # Exponential backoff with jitter - wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 - logger.warning( - f"ASYNC_DELETE: Deadlock detected deleting {self.get_object_name(obj)} {obj.pk}, " - f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", - ) - time.sleep(wait_time) - # Refresh object from DB before retry - obj.refresh_from_db() - else: - logger.error( - f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {self.get_object_name(obj)} {obj.pk}: {e}", - ) - raise + +@app.task +def async_delete_chunk_task(objects, **kwargs): + """ + Module-level Celery task to delete a chunk of objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + max_retries = 3 + for obj in objects: + retry_count = 0 + while retry_count < max_retries: + try: + obj.delete() + break # Success, exit retry loop + except OperationalError as e: + error_msg = str(e) + if "deadlock detected" in error_msg.lower(): + retry_count += 1 + if retry_count < max_retries: + # Exponential backoff with jitter + wait_time = (2 ** retry_count) + random.uniform(0, 1) # noqa: S311 + logger.warning( + f"ASYNC_DELETE: Deadlock detected deleting {_get_object_name(obj)} {obj.pk}, " + f"retrying ({retry_count}/{max_retries}) after {wait_time:.2f}s", + ) + time.sleep(wait_time) + # Refresh object from DB before retry + obj.refresh_from_db() else: - # Not a deadlock, re-raise + logger.error( + f"ASYNC_DELETE: Deadlock persisted after {max_retries} retries for {_get_object_name(obj)} {obj.pk}: {e}", + ) raise - except AssertionError: - logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. Skipping") - # The id must be None - # The object has already been deleted elsewhere - break - except LogEntry.MultipleObjectsReturned: - # Delete the log entrys first, then delete - LogEntry.objects.filter( - content_type=ContentType.objects.get_for_model(obj.__class__), - object_pk=str(obj.pk), - action=LogEntry.Action.DELETE, - ).delete() - # Now delete the object again (no retry needed for this case) - obj.delete() - break - - @dojo_async_task - @app.task - def delete(self, obj, **kwargs): - logger.debug("ASYNC_DELETE: Deleting " + self.get_object_name(obj) + ": " + str(obj)) - model_list = self.mapping.get(self.get_object_name(obj), None) - if model_list: - # The object to be deleted was found in the object list - self.crawl(obj, model_list) - else: - # The object is not supported in async delete, delete normally - logger.debug("ASYNC_DELETE: " + self.get_object_name(obj) + " async delete not supported. 
Deleteing normally: " + str(obj)) - obj.delete() - - @dojo_async_task - @app.task - def crawl(self, obj, model_list, **kwargs): - logger.debug("ASYNC_DELETE: Crawling " + self.get_object_name(obj) + ": " + str(obj)) - with Endpoint.allow_endpoint_init(): # TODO: Delete this after the move to Locations - for model_info in model_list: - task_results = [] - model = model_info[0] - model_query = model_info[1] - filter_dict = {model_query: obj.id} - # Only fetch the IDs since we will make a list of IDs in the following function call - objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") - logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + self.get_object_name(model) + "s in chunks") - chunks = self.chunk_list(model, objects_to_delete) - for chunk in chunks: - logger.debug(f"deleting {len(chunk)} {self.get_object_name(model)}") - result = self.delete_chunk(chunk) - # Collect async task results to wait for them all at once - if hasattr(result, "get"): - task_results.append(result) - # Wait for all chunk deletions to complete (they run in parallel) - for task_result in task_results: - task_result.get(timeout=300) # 5 minute timeout per chunk - # Now delete the main object after all chunks are done - result = self.delete_chunk([obj]) - # Wait for final deletion to complete + else: + # Not a deadlock, re-raise + raise + except AssertionError: + logger.debug("ASYNC_DELETE: object has already been deleted elsewhere. Skipping") + # The id must be None + # The object has already been deleted elsewhere + break + except LogEntry.MultipleObjectsReturned: + # Delete the log entrys first, then delete + LogEntry.objects.filter( + content_type=ContentType.objects.get_for_model(obj.__class__), + object_pk=str(obj.pk), + action=LogEntry.Action.DELETE, + ).delete() + # Now delete the object again (no retry needed for this case) + obj.delete() + break + + +@app.task +def async_delete_crawl_task(obj, model_list, **kwargs): + """ + Module-level Celery task to crawl and delete related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. 
+ """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Crawling " + _get_object_name(obj) + ": " + str(obj)) + for model_info in model_list: + task_results = [] + model = model_info[0] + model_query = model_info[1] + filter_dict = {model_query: obj.id} + # Only fetch the IDs since we will make a list of IDs in the following function call + objects_to_delete = model.objects.only("id").filter(**filter_dict).distinct().order_by("id") + logger.debug("ASYNC_DELETE: Deleting " + str(len(objects_to_delete)) + " " + _get_object_name(model) + "s in chunks") + chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE") + chunks = [objects_to_delete[i:i + chunk_size] for i in range(0, len(objects_to_delete), chunk_size)] + logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunks)) + " chunks of " + str(chunk_size)) + for chunk in chunks: + logger.debug(f"deleting {len(chunk)} {_get_object_name(model)}") + result = dojo_dispatch_task(async_delete_chunk_task, list(chunk)) + # Collect async task results to wait for them all at once if hasattr(result, "get"): - result.get(timeout=300) # 5 minute timeout - logger.debug("ASYNC_DELETE: Successfully deleted " + self.get_object_name(obj) + ": " + str(obj)) + task_results.append(result) + # Wait for all chunk deletions to complete (they run in parallel) + for task_result in task_results: + task_result.get(timeout=300) # 5 minute timeout per chunk + # Now delete the main object after all chunks are done + result = dojo_dispatch_task(async_delete_chunk_task, [obj]) + # Wait for final deletion to complete + if hasattr(result, "get"): + result.get(timeout=300) # 5 minute timeout + logger.debug("ASYNC_DELETE: Successfully deleted " + _get_object_name(obj) + ": " + str(obj)) - def chunk_list(self, model, full_list): + +@app.task +def async_delete_task(obj, **kwargs): + """ + Module-level Celery task to delete an object and its related objects. + + Accepts **kwargs for async_user and _pgh_context injected by dojo_dispatch_task. + Uses PgHistoryTask base class (default) to preserve pghistory context for audit trail. + """ + from dojo.celery_dispatch import dojo_dispatch_task # noqa: PLC0415 circular import + + logger.debug("ASYNC_DELETE: Deleting " + _get_object_name(obj) + ": " + str(obj)) + model_list = ASYNC_DELETE_MAPPING.get(_get_object_name(obj)) + if model_list: + # The object to be deleted was found in the object list + dojo_dispatch_task(async_delete_crawl_task, obj, model_list) + else: + # The object is not supported in async delete, delete normally + logger.debug("ASYNC_DELETE: " + _get_object_name(obj) + " async delete not supported. Deleteing normally: " + str(obj)) + obj.delete() + + +class async_delete: + + """ + Entry point class for async object deletion. + + Usage: + async_del = async_delete() + async_del.delete(instance) + + This class dispatches deletion to module-level Celery tasks via dojo_dispatch_task, + which properly handles user context injection and pghistory context. + """ + + def __init__(self, *args, **kwargs): + # Keep mapping reference for backwards compatibility + self.mapping = ASYNC_DELETE_MAPPING + + def delete(self, obj, **kwargs): + """ + Entry point to delete an object asynchronously. + + Dispatches to async_delete_task via dojo_dispatch_task to ensure proper + handling of async_user and _pgh_context. 
+        """
+        from dojo.celery_dispatch import dojo_dispatch_task  # noqa: PLC0415 circular import
+
+        dojo_dispatch_task(async_delete_task, obj, **kwargs)
+
+    # Keep helper methods for backwards compatibility and potential direct use
+    @staticmethod
+    def get_object_name(obj):
+        return _get_object_name(obj)
+
+    @staticmethod
+    def chunk_list(model, full_list):
         chunk_size = get_setting("ASYNC_OBEJECT_DELETE_CHUNK_SIZE")
-        # Break the list of objects into "chunk_size" lists
         chunk_list = [full_list[i:i + chunk_size] for i in range(0, len(full_list), chunk_size)]
-        logger.debug("ASYNC_DELETE: Split " + self.get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size))
+        logger.debug("ASYNC_DELETE: Split " + _get_object_name(model) + " into " + str(len(chunk_list)) + " chunks of " + str(chunk_size))
         return chunk_list
 
-    def get_object_name(self, obj):
-        if obj.__class__.__name__ == "ModelBase":
-            return obj.__name__
-        return obj.__class__.__name__
-
 
 @receiver(user_logged_in)
 def log_user_login(sender, request, user, **kwargs):
diff --git a/unittests/test_async_delete.py b/unittests/test_async_delete.py
new file mode 100644
index 00000000000..341723e8296
--- /dev/null
+++ b/unittests/test_async_delete.py
@@ -0,0 +1,313 @@
+"""
+Unit tests for async_delete functionality.
+
+These tests verify that the async_delete class works correctly with dojo_dispatch_task,
+which injects async_user and _pgh_context kwargs into task calls.
+
+The original bug was that @app.task decorated instance methods didn't properly handle
+the injected kwargs, causing TypeError: unexpected keyword argument 'async_user'.
+"""
+import logging
+
+from crum import impersonate
+from django.contrib.auth.models import User
+from django.test import override_settings
+from django.utils import timezone
+
+from dojo.models import Engagement, Finding, Product, Product_Type, Test, Test_Type, UserContactInfo
+from dojo.utils import async_delete
+
+from .dojo_test_case import DojoTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class TestAsyncDelete(DojoTestCase):
+
+    """
+    Test async_delete functionality with dojo_dispatch_task kwargs injection.
+
+    These tests use block_execution=True and crum.impersonate to run tasks synchronously,
+    which allows errors to surface immediately rather than being lost in background workers.
+    """
+
+    def setUp(self):
+        """Set up test user with block_execution=True and disable unneeded features."""
+        super().setUp()
+
+        # Create test user with block_execution=True to run tasks synchronously
+        self.testuser = User.objects.create(
+            username="test_async_delete_user",
+            is_staff=True,
+            is_superuser=True,
+        )
+        UserContactInfo.objects.create(user=self.testuser, block_execution=True)
+
+        # Log in as the test user (for API client)
+        self.client.force_login(self.testuser)
+
+        # Disable features that might interfere with deletion
+        self.system_settings(enable_product_grade=False)
+        self.system_settings(enable_github=False)
+        self.system_settings(enable_jira=False)
+
+        # Create base test data
+        self.product_type = Product_Type.objects.create(name="Test Product Type for Async Delete")
+        self.test_type = Test_Type.objects.get_or_create(name="Manual Test")[0]
+
+    def tearDown(self):
+        """Clean up any remaining test data."""
+        # Clean up in reverse order of dependencies
+        Finding.objects.filter(test__engagement__product__prod_type=self.product_type).delete()
+        Test.objects.filter(engagement__product__prod_type=self.product_type).delete()
+        Engagement.objects.filter(product__prod_type=self.product_type).delete()
+        Product.objects.filter(prod_type=self.product_type).delete()
+        self.product_type.delete()
+
+        super().tearDown()
+
+    def _create_product(self, name="Test Product"):
+        """Helper to create a product for testing."""
+        return Product.objects.create(
+            name=name,
+            description="Test product for async delete",
+            prod_type=self.product_type,
+        )
+
+    def _create_engagement(self, product, name="Test Engagement"):
+        """Helper to create an engagement for testing."""
+        return Engagement.objects.create(
+            name=name,
+            product=product,
+            target_start=timezone.now(),
+            target_end=timezone.now(),
+        )
+
+    def _create_test(self, engagement, name="Test"):
+        """Helper to create a test for testing."""
+        return Test.objects.create(
+            engagement=engagement,
+            test_type=self.test_type,
+            target_start=timezone.now(),
+            target_end=timezone.now(),
+        )
+
+    def _create_finding(self, test, title="Test Finding"):
+        """Helper to create a finding for testing."""
+        return Finding.objects.create(
+            test=test,
+            title=title,
+            severity="High",
+            description="Test finding for async delete",
+            mitigation="Test mitigation",
+            impact="Test impact",
+            reporter=self.testuser,
+        )
+
+    @override_settings(ASYNC_OBJECT_DELETE=True)
+    def test_async_delete_simple_object(self):
+        """
+        Test that async_delete works for a simple object (Finding).
+
+        Finding is not in the async_delete mapping, so it falls back to direct delete.
+        This tests that the module-level task accepts **kwargs properly.
+        """
+        product = self._create_product()
+        engagement = self._create_engagement(product)
+        test = self._create_test(engagement)
+        finding = self._create_finding(test)
+        finding_pk = finding.pk
+
+        # Use impersonate to set current user context (required for block_execution to work)
+        with impersonate(self.testuser):
+            # This would raise TypeError before the fix:
+            # TypeError: delete() got an unexpected keyword argument 'async_user'
+            async_del = async_delete()
+            async_del.delete(finding)
+
+        # Verify the finding was deleted
+        self.assertFalse(
+            Finding.objects.filter(pk=finding_pk).exists(),
+            "Finding should be deleted",
+        )
+
+    @override_settings(ASYNC_OBJECT_DELETE=True)
+    def test_async_delete_test_with_findings(self):
+        """
+        Test that async_delete cascades deletion for Test objects.
+
+        Test is in the async_delete mapping and should cascade delete its findings.
+        """
+        product = self._create_product()
+        engagement = self._create_engagement(product)
+        test = self._create_test(engagement)
+        finding1 = self._create_finding(test, "Finding 1")
+        finding2 = self._create_finding(test, "Finding 2")
+
+        test_pk = test.pk
+        finding1_pk = finding1.pk
+        finding2_pk = finding2.pk
+
+        # Use impersonate to set current user context (required for block_execution to work)
+        with impersonate(self.testuser):
+            # Delete the test (should cascade to findings)
+            async_del = async_delete()
+            async_del.delete(test)
+
+        # Verify all objects were deleted
+        self.assertFalse(
+            Test.objects.filter(pk=test_pk).exists(),
+            "Test should be deleted",
+        )
+        self.assertFalse(
+            Finding.objects.filter(pk=finding1_pk).exists(),
+            "Finding 1 should be deleted via cascade",
+        )
+        self.assertFalse(
+            Finding.objects.filter(pk=finding2_pk).exists(),
+            "Finding 2 should be deleted via cascade",
+        )
+
+    @override_settings(ASYNC_OBJECT_DELETE=True)
+    def test_async_delete_engagement_with_tests(self):
+        """
+        Test that async_delete cascades deletion for Engagement objects.
+
+        Engagement is in the async_delete mapping and should cascade delete
+        its tests and findings.
+        """
+        product = self._create_product()
+        engagement = self._create_engagement(product)
+        test1 = self._create_test(engagement, "Test 1")
+        test2 = self._create_test(engagement, "Test 2")
+        finding1 = self._create_finding(test1, "Finding in Test 1")
+        finding2 = self._create_finding(test2, "Finding in Test 2")
+
+        engagement_pk = engagement.pk
+        test1_pk = test1.pk
+        test2_pk = test2.pk
+        finding1_pk = finding1.pk
+        finding2_pk = finding2.pk
+
+        # Use impersonate to set current user context (required for block_execution to work)
+        with impersonate(self.testuser):
+            # Delete the engagement (should cascade to tests and findings)
+            async_del = async_delete()
+            async_del.delete(engagement)
+
+        # Verify all objects were deleted
+        self.assertFalse(
+            Engagement.objects.filter(pk=engagement_pk).exists(),
+            "Engagement should be deleted",
+        )
+        self.assertFalse(
+            Test.objects.filter(pk__in=[test1_pk, test2_pk]).exists(),
+            "Tests should be deleted via cascade",
+        )
+        self.assertFalse(
+            Finding.objects.filter(pk__in=[finding1_pk, finding2_pk]).exists(),
+            "Findings should be deleted via cascade",
+        )
+
+    @override_settings(ASYNC_OBJECT_DELETE=True)
+    def test_async_delete_product_with_hierarchy(self):
+        """
+        Test that async_delete cascades deletion for Product objects.
+
+        Product is in the async_delete mapping and should cascade delete
+        its engagements, tests, and findings.
+        """
+        product = self._create_product()
+        engagement = self._create_engagement(product)
+        test = self._create_test(engagement)
+        finding = self._create_finding(test)
+
+        product_pk = product.pk
+        engagement_pk = engagement.pk
+        test_pk = test.pk
+        finding_pk = finding.pk
+
+        # Use impersonate to set current user context (required for block_execution to work)
+        with impersonate(self.testuser):
+            # Delete the product (should cascade to everything)
+            async_del = async_delete()
+            async_del.delete(product)
+
+        # Verify all objects were deleted
+        self.assertFalse(
+            Product.objects.filter(pk=product_pk).exists(),
+            "Product should be deleted",
+        )
+        self.assertFalse(
+            Engagement.objects.filter(pk=engagement_pk).exists(),
+            "Engagement should be deleted via cascade",
+        )
+        self.assertFalse(
+            Test.objects.filter(pk=test_pk).exists(),
+            "Test should be deleted via cascade",
+        )
+        self.assertFalse(
+            Finding.objects.filter(pk=finding_pk).exists(),
+            "Finding should be deleted via cascade",
+        )
+
+    @override_settings(ASYNC_OBJECT_DELETE=True)
+    def test_async_delete_accepts_sync_kwarg(self):
+        """
+        Test that async_delete passes through the sync kwarg properly.
+
+        The sync=True kwarg forces synchronous execution for the top-level task.
+        However, nested task dispatches still need user context to run synchronously,
+        so we use impersonate here as well.
+        """
+        product = self._create_product()
+        product_pk = product.pk
+
+        # Use impersonate to ensure nested tasks also run synchronously
+        with impersonate(self.testuser):
+            # Explicitly pass sync=True
+            async_del = async_delete()
+            async_del.delete(product, sync=True)
+
+        # Verify the product was deleted
+        self.assertFalse(
+            Product.objects.filter(pk=product_pk).exists(),
+            "Product should be deleted with sync=True",
+        )
+
+    def test_async_delete_helper_methods(self):
+        """
+        Test that static helper methods on async_delete class still work.
+
+        These are kept for backwards compatibility.
+        """
+        product = self._create_product()
+
+        # Test get_object_name
+        self.assertEqual(
+            async_delete.get_object_name(product),
+            "Product",
+            "get_object_name should return class name",
+        )
+
+        # Test get_object_name with model class
+        self.assertEqual(
+            async_delete.get_object_name(Product),
+            "Product",
+            "get_object_name should work with model class",
+        )
+
+    def test_async_delete_mapping_preserved(self):
+        """
+        Test that the mapping attribute is preserved on async_delete instances.
+
+        This ensures backwards compatibility for code that might access the mapping.
+        """
+        async_del = async_delete()
+
+        # Verify mapping exists and has expected keys
+        self.assertIsNotNone(async_del.mapping)
+        self.assertIn("Product", async_del.mapping)
+        self.assertIn("Product_Type", async_del.mapping)
+        self.assertIn("Engagement", async_del.mapping)
+        self.assertIn("Test", async_del.mapping)
diff --git a/unittests/test_jira_import_and_pushing_api.py b/unittests/test_jira_import_and_pushing_api.py
index e1a8284698e..eb762a74313 100644
--- a/unittests/test_jira_import_and_pushing_api.py
+++ b/unittests/test_jira_import_and_pushing_api.py
@@ -981,7 +981,7 @@ def test_engagement_epic_mapping_disabled_no_epic_and_push_findings(self):
     @patch("dojo.jira_link.helper.can_be_pushed_to_jira", return_value=(True, None, None))
     @patch("dojo.jira_link.helper.is_push_all_issues", return_value=False)
     @patch("dojo.jira_link.helper.push_to_jira", return_value=None)
-    @patch("dojo.notifications.helper.WebhookNotificationManger.send_webhooks_notification")
+    @patch("dojo.notifications.helper.send_webhooks_notification")
     def test_bulk_edit_mixed_findings_and_groups_jira_push_bug(self, mock_webhooks, mock_push_to_jira, mock_is_push_all_issues, mock_can_be_pushed):
         """
         Test the bug in bulk edit: when bulk editing findings where some are in groups
diff --git a/unittests/test_notifications.py b/unittests/test_notifications.py
index fc547f526a1..2e897702fc4 100644
--- a/unittests/test_notifications.py
+++ b/unittests/test_notifications.py
@@ -680,7 +680,7 @@ def test_webhook_reactivation(self):
         with self.subTest("active"):
             wh = Notification_Webhooks.objects.filter(owner=None).first()
             manager = WebhookNotificationManger()
-            manager._webhook_reactivation(manager, endpoint_id=wh.pk)
+            manager._webhook_reactivation(endpoint_id=wh.pk)
             updated_wh = Notification_Webhooks.objects.filter(owner=None).first()
             self.assertEqual(updated_wh.status, Notification_Webhooks.Status.STATUS_ACTIVE)
@@ -699,7 +699,7 @@ def test_webhook_reactivation(self):
 
         with self.assertLogs("dojo.notifications.helper", level="DEBUG") as cm:
             manager = WebhookNotificationManger()
-            manager._webhook_reactivation(manager, endpoint_id=wh.pk)
+            manager._webhook_reactivation(endpoint_id=wh.pk)
            updated_wh = Notification_Webhooks.objects.filter(owner=None).first()
            self.assertEqual(updated_wh.status, Notification_Webhooks.Status.STATUS_ACTIVE_TMP)
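Reviewer note (not part of the patch): the short sketch below illustrates how application code is expected to call the refactored entry point, assuming ASYNC_OBJECT_DELETE is enabled and the acting user's UserContactInfo has block_execution set, mirroring the pattern in unittests/test_async_delete.py above. The helper name delete_product_for_user is hypothetical and used only for illustration.

# Illustrative sketch only (hypothetical helper, not part of this patch).
# It mirrors the tests above: run the delete while a user is "current" so
# dojo_dispatch_task can inject async_user (and _pgh_context) into the task.
from crum import impersonate

from dojo.models import Product
from dojo.utils import async_delete


def delete_product_for_user(product: Product, user) -> None:
    # With ASYNC_OBJECT_DELETE=True this dispatches async_delete_task via
    # dojo_dispatch_task; if the user's UserContactInfo has block_execution=True,
    # the task chain runs inline, as it does in the unit tests.
    with impersonate(user):
        async_delete().delete(product)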