diff --git a/dojo/finding/helper.py b/dojo/finding/helper.py index 7808fdd7cf5..0d2455fc777 100644 --- a/dojo/finding/helper.py +++ b/dojo/finding/helper.py @@ -986,14 +986,15 @@ def add_locations(finding, form, *, replace=False): return set(locations_to_associate) -def sanitize_vulnerability_ids(vulnerability_ids) -> None: +def sanitize_vulnerability_ids(vulnerability_ids): """Remove undisired vulnerability id values""" - vulnerability_ids = [x for x in vulnerability_ids if x.strip()] + return [x for x in vulnerability_ids if x.strip()] def save_vulnerability_ids(finding, vulnerability_ids, *, delete_existing: bool = True): - # Remove duplicates + # Remove duplicates and empty/whitespace IDs vulnerability_ids = list(dict.fromkeys(vulnerability_ids)) + vulnerability_ids = sanitize_vulnerability_ids(vulnerability_ids) # Remove old vulnerability ids if requested # Callers can set delete_existing=False when they know there are no existing IDs @@ -1001,12 +1002,10 @@ def save_vulnerability_ids(finding, vulnerability_ids, *, delete_existing: bool if delete_existing: Vulnerability_Id.objects.filter(finding=finding).delete() - # Remove undisired vulnerability ids - sanitize_vulnerability_ids(vulnerability_ids) - # Save new vulnerability ids - # Using bulk create throws Django 50 warnings about unsaved models... - for vulnerability_id in vulnerability_ids: - Vulnerability_Id(finding=finding, vulnerability_id=vulnerability_id).save() + Vulnerability_Id.objects.bulk_create([ + Vulnerability_Id(finding=finding, vulnerability_id=vid) + for vid in vulnerability_ids + ]) # Set CVE if vulnerability_ids: diff --git a/dojo/importers/base_importer.py b/dojo/importers/base_importer.py index d87524185fe..35194c7d351 100644 --- a/dojo/importers/base_importer.py +++ b/dojo/importers/base_importer.py @@ -31,6 +31,7 @@ Test_Import, Test_Import_Finding_Action, Test_Type, + Vulnerability_Id, ) from dojo.notifications.helper import create_notification from dojo.tags.utils import bulk_add_tags_to_instances @@ -77,6 +78,8 @@ def __init__( and will raise a `NotImplemented` exception """ ImporterOptions.__init__(self, *args, **kwargs) + self.pending_vulnerability_ids: list[Vulnerability_Id] = [] + self.pending_vuln_id_deletes: list[int] = [] def check_child_implementation_exception(self): """ @@ -778,21 +781,31 @@ def store_vulnerability_ids( finding: Finding, ) -> Finding: """ - Store vulnerability IDs for a finding. - Reads from finding.unsaved_vulnerability_ids and saves them overwriting existing ones. - - Args: - finding: The finding to store vulnerability IDs for - - Returns: - The finding object - + Accumulate Vulnerability_Id objects for bulk insert at the batch boundary. + Call flush_vulnerability_ids() to persist. """ self.sanitize_vulnerability_ids(finding) - vulnerability_ids_to_process = finding.unsaved_vulnerability_ids or [] - finding_helper.save_vulnerability_ids(finding, vulnerability_ids_to_process, delete_existing=False) + vulnerability_ids_to_process = list(dict.fromkeys(finding.unsaved_vulnerability_ids or [])) + vulnerability_ids_to_process = [x for x in vulnerability_ids_to_process if x.strip()] + self.pending_vulnerability_ids.extend([ + Vulnerability_Id(finding=finding, vulnerability_id=vid) + for vid in vulnerability_ids_to_process + ]) + if vulnerability_ids_to_process: + finding.cve = vulnerability_ids_to_process[0] + else: + finding.cve = None return finding + def flush_vulnerability_ids(self) -> None: + """Delete stale and bulk-insert accumulated Vulnerability_Id objects, then clear buffers.""" + if self.pending_vuln_id_deletes: + Vulnerability_Id.objects.filter(finding_id__in=self.pending_vuln_id_deletes).delete() + self.pending_vuln_id_deletes.clear() + if self.pending_vulnerability_ids: + Vulnerability_Id.objects.bulk_create(self.pending_vulnerability_ids, batch_size=1000) + self.pending_vulnerability_ids.clear() + def process_files( self, finding: Finding, diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 3a920577d2d..d8d825fd732 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -275,6 +275,7 @@ def _process_findings_internal( # If batch is full or we're at the end, persist locations/endpoints and dispatch if len(batch_finding_ids) >= batch_max_size or is_final_finding: self.location_handler.persist() + self.flush_vulnerability_ids() # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. bulk_apply_parser_tags(findings_with_parser_tags) @@ -415,6 +416,7 @@ def close_old_findings( ) # Persist any accumulated location/endpoint status changes self.location_handler.persist() + self.flush_vulnerability_ids() # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index aa33c6153b0..15583e332cf 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -23,6 +23,7 @@ Notes, Test, Test_Import, + Vulnerability_Id, ) from dojo.tags import inheritance as tag_inheritance from dojo.tags.inheritance import apply_inherited_tags_for_findings @@ -438,6 +439,7 @@ def _process_findings_internal( # They don't need to be aligned since they optimize different operations. if len(batch_finding_ids) >= dedupe_batch_max_size or is_final: self.location_handler.persist() + self.flush_vulnerability_ids() # Apply parser-supplied tags for this batch before post-processing starts, # so rules/deduplication tasks see the tags already on the findings. bulk_apply_parser_tags(findings_with_parser_tags) @@ -561,6 +563,7 @@ def close_old_findings( mitigated_findings.append(finding) # Persist any accumulated location/endpoint status changes self.location_handler.persist() + self.flush_vulnerability_ids() # push finding groups to jira since we only only want to push whole groups # We dont check if the finding jira sync is applicable quite yet until we can get in the loop # but this is a way to at least make it that far @@ -944,24 +947,17 @@ def reconcile_vulnerability_ids( ) -> Finding: """ Reconcile vulnerability IDs for an existing finding. - Checks if IDs have changed before updating to avoid unnecessary database operations. - Uses prefetched data if available, otherwise fetches efficiently. - - Args: - finding: The existing finding to reconcile vulnerability IDs for. - Must have unsaved_vulnerability_ids set. - - Returns: - The finding object - + Accumulates changes into pending_vuln_id_deletes / pending_vulnerability_ids + for batch flush at the batch boundary via flush_vulnerability_ids(). """ - vulnerability_ids_to_process = finding.unsaved_vulnerability_ids or [] + vulnerability_ids_to_process = list(dict.fromkeys(finding.unsaved_vulnerability_ids or [])) + vulnerability_ids_to_process = [x for x in vulnerability_ids_to_process if x.strip()] # Use prefetched data directly without triggering queries existing_vuln_ids = {v.vulnerability_id for v in finding.vulnerability_id_set.all()} new_vuln_ids = set(vulnerability_ids_to_process) - # Early exit if unchanged + # Early exit if unchanged — no DB work needed if existing_vuln_ids == new_vuln_ids: logger.debug( f"Skipping vulnerability_ids update for finding {finding.id} - " @@ -969,8 +965,16 @@ def reconcile_vulnerability_ids( ) return finding - # Update if changed - finding_helper.save_vulnerability_ids(finding, vulnerability_ids_to_process, delete_existing=True) + # Accumulate delete + insert for batch flush + self.pending_vuln_id_deletes.append(finding.id) + self.pending_vulnerability_ids.extend([ + Vulnerability_Id(finding=finding, vulnerability_id=vid) + for vid in vulnerability_ids_to_process + ]) + if vulnerability_ids_to_process: + finding.cve = vulnerability_ids_to_process[0] + else: + finding.cve = None return finding def finding_post_processing( diff --git a/unittests/test_finding_helper.py b/unittests/test_finding_helper.py index fa6fd2d9ea5..f6d9e747ff2 100644 --- a/unittests/test_finding_helper.py +++ b/unittests/test_finding_helper.py @@ -220,8 +220,8 @@ class TestSaveVulnerabilityIds(DojoTestCase): @patch("dojo.finding.helper.Vulnerability_Id.objects.filter") @patch("django.db.models.query.QuerySet.delete") - @patch("dojo.finding.helper.Vulnerability_Id.save") - def test_save_vulnerability_ids(self, save_mock, delete_mock, filter_mock): + @patch("dojo.finding.helper.Vulnerability_Id.objects.bulk_create") + def test_save_vulnerability_ids(self, bulk_create_mock, delete_mock, filter_mock): finding = Finding() new_vulnerability_ids = ["REF-1", "REF-2", "REF-2"] filter_mock.return_value = Vulnerability_Id.objects.none() @@ -230,7 +230,10 @@ def test_save_vulnerability_ids(self, save_mock, delete_mock, filter_mock): filter_mock.assert_called_with(finding=finding) delete_mock.assert_called_once() - self.assertEqual(save_mock.call_count, 2) + bulk_create_mock.assert_called_once() + # Duplicates are removed: REF-1 and REF-2 only + created_objects = bulk_create_mock.call_args[0][0] + self.assertEqual(2, len(created_objects)) self.assertEqual("REF-1", finding.cve) @patch("dojo.models.Finding_Template.save") diff --git a/unittests/test_importers_importer.py b/unittests/test_importers_importer.py index aa14ace8beb..b5ac6a276e8 100644 --- a/unittests/test_importers_importer.py +++ b/unittests/test_importers_importer.py @@ -803,14 +803,15 @@ def create_default_data(self): } def test_handle_vulnerability_ids_references_and_cve(self): - # Why doesn't this test use the test db and query for one? vulnerability_ids = ["CVE", "REF-1", "REF-2"] finding = Finding() finding.unsaved_vulnerability_ids = vulnerability_ids finding.test = self.test finding.reporter = self.testuser finding.save() - DefaultImporter(**self.importer_data).store_vulnerability_ids(finding) + importer = DefaultImporter(**self.importer_data) + importer.store_vulnerability_ids(finding) + importer.flush_vulnerability_ids() self.assertEqual("CVE", finding.vulnerability_ids[0]) self.assertEqual("CVE", finding.cve) @@ -827,7 +828,9 @@ def test_handle_no_vulnerability_ids_references_and_cve(self): finding.save() finding.unsaved_vulnerability_ids = vulnerability_ids - DefaultImporter(**self.importer_data).store_vulnerability_ids(finding) + importer = DefaultImporter(**self.importer_data) + importer.store_vulnerability_ids(finding) + importer.flush_vulnerability_ids() self.assertEqual("CVE", finding.vulnerability_ids[0]) self.assertEqual("CVE", finding.cve) @@ -841,7 +844,9 @@ def test_handle_vulnerability_ids_references_and_no_cve(self): finding.reporter = self.testuser finding.save() finding.unsaved_vulnerability_ids = vulnerability_ids - DefaultImporter(**self.importer_data).store_vulnerability_ids(finding) + importer = DefaultImporter(**self.importer_data) + importer.store_vulnerability_ids(finding) + importer.flush_vulnerability_ids() self.assertEqual("REF-1", finding.vulnerability_ids[0]) self.assertEqual("REF-1", finding.cve) @@ -854,7 +859,9 @@ def test_no_handle_vulnerability_ids_references_and_no_cve(self): finding.test = self.test finding.reporter = self.testuser finding.save() - DefaultImporter(**self.importer_data).store_vulnerability_ids(finding) + importer = DefaultImporter(**self.importer_data) + importer.store_vulnerability_ids(finding) + importer.flush_vulnerability_ids() self.assertEqual(finding.cve, None) self.assertEqual(finding.unsaved_vulnerability_ids, None) self.assertEqual(finding.vulnerability_ids, []) @@ -880,7 +887,9 @@ def test_clear_vulnerability_ids_on_empty_list(self): # Process with empty list - should clear all IDs finding.unsaved_vulnerability_ids = [] - DefaultReImporter(test=self.test, environment=self.importer_data["environment"], scan_type=self.importer_data["scan_type"]).reconcile_vulnerability_ids(finding) + reimporter = DefaultReImporter(test=self.test, environment=self.importer_data["environment"], scan_type=self.importer_data["scan_type"]) + reimporter.reconcile_vulnerability_ids(finding) + reimporter.flush_vulnerability_ids() # Save the finding to persist the cve=None change finding.save() @@ -917,7 +926,9 @@ def test_change_vulnerability_ids_on_reimport(self): # Process with different IDs - should replace old IDs new_vulnerability_ids = ["CVE-2021-9999", "GHSA-xxxx-yyyy"] finding.unsaved_vulnerability_ids = new_vulnerability_ids - DefaultReImporter(test=self.test, environment=self.importer_data["environment"], scan_type=self.importer_data["scan_type"]).reconcile_vulnerability_ids(finding) + reimporter = DefaultReImporter(test=self.test, environment=self.importer_data["environment"], scan_type=self.importer_data["scan_type"]) + reimporter.reconcile_vulnerability_ids(finding) + reimporter.flush_vulnerability_ids() # Save the finding to persist the cve change finding.save() @@ -933,3 +944,90 @@ def test_change_vulnerability_ids_on_reimport(self): vuln_ids = list(Vulnerability_Id.objects.filter(finding=finding).values_list("vulnerability_id", flat=True)) self.assertEqual(set(new_vulnerability_ids), set(vuln_ids)) finding.delete() + + def test_reconcile_vulnerability_ids_cross_finding_batch(self): + """Multiple findings accumulated before flush — one delete+insert pair per changed finding.""" + reimporter = DefaultReImporter(test=self.test, environment=self.importer_data["environment"], scan_type=self.importer_data["scan_type"]) + + # finding_a: IDs change (CVE-A → CVE-B) + finding_a = Finding(test=self.test, reporter=self.testuser) + finding_a.save() + Vulnerability_Id.objects.create(finding=finding_a, vulnerability_id="CVE-A-OLD") + finding_a.cve = "CVE-A-OLD" + finding_a.save() + + # finding_b: IDs change (CVE-B1, CVE-B2 → CVE-B-NEW) + finding_b = Finding(test=self.test, reporter=self.testuser) + finding_b.save() + Vulnerability_Id.objects.create(finding=finding_b, vulnerability_id="CVE-B1") + Vulnerability_Id.objects.create(finding=finding_b, vulnerability_id="CVE-B2") + finding_b.cve = "CVE-B1" + finding_b.save() + + # finding_c: IDs unchanged — should not appear in delete/insert buffers + finding_c = Finding(test=self.test, reporter=self.testuser) + finding_c.save() + Vulnerability_Id.objects.create(finding=finding_c, vulnerability_id="CVE-C-SAME") + finding_c.cve = "CVE-C-SAME" + finding_c.save() + + finding_a.unsaved_vulnerability_ids = ["CVE-A-NEW"] + finding_b.unsaved_vulnerability_ids = ["CVE-B-NEW"] + finding_c.unsaved_vulnerability_ids = ["CVE-C-SAME"] + + # Accumulate all three before any flush + reimporter.reconcile_vulnerability_ids(finding_a) + reimporter.reconcile_vulnerability_ids(finding_b) + reimporter.reconcile_vulnerability_ids(finding_c) + + # pending_vuln_id_deletes only contains changed findings, not finding_c + self.assertIn(finding_a.id, reimporter.pending_vuln_id_deletes) + self.assertIn(finding_b.id, reimporter.pending_vuln_id_deletes) + self.assertNotIn(finding_c.id, reimporter.pending_vuln_id_deletes) + self.assertEqual(2, len(reimporter.pending_vulnerability_ids)) + + # Old IDs still in DB (not yet deleted) + self.assertEqual(1, Vulnerability_Id.objects.filter(finding=finding_a).count()) + self.assertEqual(2, Vulnerability_Id.objects.filter(finding=finding_b).count()) + + reimporter.flush_vulnerability_ids() + + # Buffers cleared + self.assertEqual([], reimporter.pending_vuln_id_deletes) + self.assertEqual([], reimporter.pending_vulnerability_ids) + + # finding_a: old deleted, new inserted + vuln_ids_a = list(Vulnerability_Id.objects.filter(finding=finding_a).values_list("vulnerability_id", flat=True)) + self.assertEqual(["CVE-A-NEW"], vuln_ids_a) + self.assertEqual("CVE-A-NEW", finding_a.cve) + + # finding_b: both old deleted, new inserted + vuln_ids_b = list(Vulnerability_Id.objects.filter(finding=finding_b).values_list("vulnerability_id", flat=True)) + self.assertEqual(["CVE-B-NEW"], vuln_ids_b) + self.assertEqual("CVE-B-NEW", finding_b.cve) + + # finding_c: unchanged — IDs untouched + vuln_ids_c = list(Vulnerability_Id.objects.filter(finding=finding_c).values_list("vulnerability_id", flat=True)) + self.assertEqual(["CVE-C-SAME"], vuln_ids_c) + + finding_a.delete() + finding_b.delete() + finding_c.delete() + + def test_reconcile_vulnerability_ids_unchanged_no_db_write(self): + """Early-exit path: unchanged IDs never touch pending buffers.""" + reimporter = DefaultReImporter(test=self.test, environment=self.importer_data["environment"], scan_type=self.importer_data["scan_type"]) + + finding = Finding(test=self.test, reporter=self.testuser) + finding.save() + Vulnerability_Id.objects.create(finding=finding, vulnerability_id="CVE-2020-1234") + finding.cve = "CVE-2020-1234" + finding.save() + + finding.unsaved_vulnerability_ids = ["CVE-2020-1234"] + reimporter.reconcile_vulnerability_ids(finding) + + self.assertEqual([], reimporter.pending_vuln_id_deletes) + self.assertEqual([], reimporter.pending_vulnerability_ids) + + finding.delete() diff --git a/unittests/test_importers_performance.py b/unittests/test_importers_performance.py index dc82f28114d..a6e26b37f06 100644 --- a/unittests/test_importers_performance.py +++ b/unittests/test_importers_performance.py @@ -20,7 +20,6 @@ import logging from contextlib import contextmanager -from unittest import skip from unittest.mock import patch from crum import impersonate @@ -275,11 +274,6 @@ def _import_reimport_performance( self.assertGreater(len_closed_findings4, 0, "Step 4 (empty reimport with close_old_findings=True) should close findings") -@skip("Re-baseline pending: Track B legacy authorization reduces auth-layer query " - "overhead (no per-action role-permission lookups, simpler permission_to_action " - "dispatch). Expected query counts here were calibrated under RBAC and are " - "consistently 1-7 queries higher than legacy actual. Re-baseline with a fresh " - "calibration run after the upstream merge.") @tag("performance") @skip_unless_v2 class TestDojoImporterPerformanceSmall(TestDojoImporterPerformanceBase): @@ -349,13 +343,13 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - expected_num_queries1=171, + expected_num_queries1=156, expected_num_async_tasks1=2, - expected_num_queries2=124, + expected_num_queries2=121, expected_num_async_tasks2=1, - expected_num_queries3=29, + expected_num_queries3=28, expected_num_async_tasks3=1, - expected_num_queries4=100, + expected_num_queries4=99, expected_num_async_tasks4=0, ) @@ -373,13 +367,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=187, + expected_num_queries1=170, expected_num_async_tasks1=2, - expected_num_queries2=132, + expected_num_queries2=129, expected_num_async_tasks2=1, - expected_num_queries3=37, + expected_num_queries3=36, expected_num_async_tasks3=1, - expected_num_queries4=100, + expected_num_queries4=99, expected_num_async_tasks4=0, ) @@ -398,13 +392,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - expected_num_queries1=197, + expected_num_queries1=180, expected_num_async_tasks1=4, - expected_num_queries2=142, + expected_num_queries2=139, expected_num_async_tasks2=3, - expected_num_queries3=44, + expected_num_queries3=43, expected_num_async_tasks3=3, - expected_num_queries4=109, + expected_num_queries4=108, expected_num_async_tasks4=2, ) @@ -530,9 +524,9 @@ def test_deduplication_performance_pghistory_async(self): self.system_settings(enable_deduplication=True) self._deduplication_performance( - expected_num_queries1=110, + expected_num_queries1=92, expected_num_async_tasks1=2, - expected_num_queries2=90, + expected_num_queries2=72, expected_num_async_tasks2=2, check_duplicates=False, # Async mode - deduplication happens later ) @@ -551,18 +545,15 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=126, + expected_num_queries1=106, expected_num_async_tasks1=2, - expected_num_queries2=107, + expected_num_queries2=87, expected_num_async_tasks2=2, ) @tag("performance") @override_settings(V3_FEATURE_LOCATIONS=True) -@skip("Re-baseline pending: same RBAC→legacy query-count drift as " - "TestDojoImporterPerformanceSmall. See that class's skip note for the " - "rationale.") class TestDojoImporterPerformanceSmallLocations(TestDojoImporterPerformanceBase): r""" @@ -642,13 +633,13 @@ def test_import_reimport_reimport_performance_pghistory_async(self): configure_pghistory_triggers() self._import_reimport_performance( - expected_num_queries1=178, + expected_num_queries1=163, expected_num_async_tasks1=2, - expected_num_queries2=133, + expected_num_queries2=130, expected_num_async_tasks2=1, - expected_num_queries3=37, + expected_num_queries3=36, expected_num_async_tasks3=1, - expected_num_queries4=101, + expected_num_queries4=100, expected_num_async_tasks4=0, ) @@ -666,13 +657,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._import_reimport_performance( - expected_num_queries1=196, + expected_num_queries1=179, expected_num_async_tasks1=2, - expected_num_queries2=143, + expected_num_queries2=140, expected_num_async_tasks2=1, - expected_num_queries3=47, + expected_num_queries3=46, expected_num_async_tasks3=1, - expected_num_queries4=101, + expected_num_queries4=100, expected_num_async_tasks4=0, ) @@ -691,13 +682,13 @@ def test_import_reimport_reimport_performance_pghistory_no_async_with_product_gr self.system_settings(enable_product_grade=True) self._import_reimport_performance( - expected_num_queries1=209, + expected_num_queries1=192, expected_num_async_tasks1=4, - expected_num_queries2=156, + expected_num_queries2=153, expected_num_async_tasks2=3, - expected_num_queries3=54, + expected_num_queries3=53, expected_num_async_tasks3=3, - expected_num_queries4=113, + expected_num_queries4=112, expected_num_async_tasks4=2, ) @@ -798,9 +789,9 @@ def test_deduplication_performance_pghistory_async(self): self.system_settings(enable_deduplication=True) self._deduplication_performance( - expected_num_queries1=117, + expected_num_queries1=99, expected_num_async_tasks1=2, - expected_num_queries2=93, + expected_num_queries2=75, expected_num_async_tasks2=2, check_duplicates=False, # Async mode - deduplication happens later ) @@ -818,8 +809,8 @@ def test_deduplication_performance_pghistory_no_async(self): testuser.usercontactinfo.save() self._deduplication_performance( - expected_num_queries1=135, + expected_num_queries1=115, expected_num_async_tasks1=2, - expected_num_queries2=218, + expected_num_queries2=198, expected_num_async_tasks2=2, )