diff --git a/vulnerabilities/fetch.py b/vulnerabilities/fetch.py index b922127c..7f286ed0 100644 --- a/vulnerabilities/fetch.py +++ b/vulnerabilities/fetch.py @@ -50,15 +50,17 @@ def fetch_from_vulnerablecode(dataspace, batch_size, update, timeout, log_func=N timeout=timeout, log_func=log_func, ) + run_time = timer() - start_time if log_func: log_func(f"+ Created {intcomma(results.get('created', 0))} vulnerabilities") log_func(f"+ Updated {intcomma(results.get('updated', 0))} vulnerabilities") log_func(f"Completed in {humanize_time(run_time)}") - dataspace.vulnerabilities_updated_at = timezone.now() - dataspace.save(update_fields=["vulnerabilities_updated_at"]) - log_func("Dataspace.vulnerabilities_updated_at updated") + if results: + dataspace.vulnerabilities_updated_at = timezone.now() + dataspace.save(update_fields=["vulnerabilities_updated_at"]) + log_func("Dataspace.vulnerabilities_updated_at updated") def fetch_for_packages( @@ -66,12 +68,13 @@ def fetch_for_packages( ): from product_portfolio.models import ProductPackage + results = {"created": 0, "updated": 0} + object_count = queryset.count() if object_count < 1: - return + return results vulnerablecode = VulnerableCode(dataspace) - results = {"created": 0, "updated": 0} for index, batch in enumerate(chunked_queryset(queryset, batch_size), start=1): if log_func: diff --git a/vulnerabilities/tests/test_fetch.py b/vulnerabilities/tests/test_fetch.py index 108f5f2c..239931c2 100644 --- a/vulnerabilities/tests/test_fetch.py +++ b/vulnerabilities/tests/test_fetch.py @@ -58,6 +58,20 @@ def test_vulnerabilities_fetch_from_vulnerablecode( self.dataspace.refresh_from_db() self.assertIsNotNone(self.dataspace.vulnerabilities_updated_at) + buffer = io.StringIO() + dataspace_empty = Dataspace.objects.create(name="empty") + mock_fetch_for_packages.return_value = {} + fetch_from_vulnerablecode( + dataspace_empty, batch_size=1, update=True, timeout=None, log_func=buffer.write + ) + expected = ( + "0 Packages in the queue." + "+ Created 0 vulnerabilities" + "+ Updated 0 vulnerabilities" + "Completed in 0 seconds" + ) + self.assertEqual(expected, buffer.getvalue()) + @mock.patch("dejacode_toolkit.vulnerablecode.VulnerableCode.bulk_search_by_purl") def test_vulnerabilities_fetch_for_packages(self, mock_bulk_search_by_purl): buffer = io.StringIO()