Skip to content

Commit b79a0f8

Browse files
committed
Refactor output handling to use shared alert selection
Signed-off-by: lelia <lelia@socket.dev>
1 parent 3cf14b6 commit b79a0f8

File tree

1 file changed

+28
-152
lines changed

1 file changed

+28
-152
lines changed

socketsecurity/output.py

Lines changed: 28 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import json
22
import logging
33
from pathlib import Path
4-
from typing import Any, Dict, Optional, List, Set, Tuple
4+
from typing import Any, Dict, Optional
55
from .core.messages import Messages
66
from .core.classes import Diff, Issue
77
from .config import CliConfig
88
from socketsecurity.plugins.manager import PluginManager
9-
from socketsecurity.core.helper.socket_facts_loader import (
10-
load_socket_facts,
11-
get_components_with_vulnerabilities,
12-
convert_to_alerts,
9+
from socketsecurity.core.alert_selection import (
10+
clone_diff_with_selected_alerts,
11+
filter_alerts_by_reachability,
12+
load_components_with_alerts,
13+
select_diff_alerts,
1314
)
1415
from socketdev import socketdev
1516

@@ -22,131 +23,6 @@ def __init__(self, config: CliConfig, sdk: socketdev):
2223
self.config = config
2324
self.logger = logging.getLogger("socketcli")
2425

25-
@staticmethod
26-
def _normalize_purl(purl: str) -> str:
27-
if not purl:
28-
return ""
29-
30-
normalized = purl.strip().lower().replace("%40", "@")
31-
if normalized.startswith("pkg:"):
32-
normalized = normalized[4:]
33-
return normalized
34-
35-
@staticmethod
36-
def _normalize_vuln_id(vuln_id: str) -> str:
37-
if not vuln_id:
38-
return ""
39-
return vuln_id.strip().upper()
40-
41-
@staticmethod
42-
def _normalize_pkg_key(pkg_type: str, pkg_name: str, pkg_version: str) -> Tuple[str, str, str]:
43-
return (
44-
(pkg_type or "").strip().lower(),
45-
(pkg_name or "").strip().lower(),
46-
(pkg_version or "").strip().lower(),
47-
)
48-
49-
@staticmethod
50-
def _extract_issue_vuln_ids(issue: Issue) -> Set[str]:
51-
ids: Set[str] = set()
52-
props = getattr(issue, "props", None) or {}
53-
for key in ("ghsaId", "ghsa_id", "cveId", "cve_id"):
54-
value = props.get(key)
55-
if isinstance(value, str) and value.strip():
56-
ids.add(OutputHandler._normalize_vuln_id(value))
57-
return ids
58-
59-
def _load_components_with_alerts(self) -> Optional[List[Dict[str, Any]]]:
60-
facts_file = self.config.reach_output_file or ".socket.facts.json"
61-
facts_file_path = str(Path(self.config.target_path or ".") / facts_file)
62-
facts_data = load_socket_facts(facts_file_path)
63-
if not facts_data:
64-
return None
65-
66-
components = get_components_with_vulnerabilities(facts_data)
67-
return convert_to_alerts(components)
68-
69-
def _build_reachability_index(self) -> Optional[Tuple[Dict[str, Set[str]], Dict[Tuple[str, str, str], Set[str]]]]:
70-
components_with_alerts = self._load_components_with_alerts()
71-
if not components_with_alerts:
72-
self.logger.warning(
73-
"Unable to load reachability facts; falling back to blocking-based SARIF filter"
74-
)
75-
return None
76-
77-
reachable_by_purl: Dict[str, Set[str]] = {}
78-
reachable_by_pkg: Dict[Tuple[str, str, str], Set[str]] = {}
79-
80-
for component in components_with_alerts:
81-
purl = self._normalize_purl(component.get("alerts", [{}])[0].get("props", {}).get("purl", ""))
82-
pkg_type = component.get("type", "")
83-
pkg_version = component.get("version", "")
84-
namespace = (component.get("namespace") or "").strip()
85-
name = (component.get("name") or component.get("id") or "").strip()
86-
87-
pkg_names: Set[str] = {name}
88-
if namespace:
89-
pkg_names.add(f"{namespace}/{name}")
90-
91-
for alert in component.get("alerts", []):
92-
props = alert.get("props", {}) or {}
93-
if props.get("reachability") != "reachable":
94-
continue
95-
96-
vuln_ids = {
97-
self._normalize_vuln_id(props.get("ghsaId", "")),
98-
self._normalize_vuln_id(props.get("cveId", "")),
99-
}
100-
vuln_ids = {v for v in vuln_ids if v}
101-
if not vuln_ids:
102-
continue
103-
104-
if purl:
105-
if purl not in reachable_by_purl:
106-
reachable_by_purl[purl] = set()
107-
reachable_by_purl[purl].update(vuln_ids)
108-
109-
for pkg_name in pkg_names:
110-
pkg_key = self._normalize_pkg_key(pkg_type, pkg_name, pkg_version)
111-
if pkg_key not in reachable_by_pkg:
112-
reachable_by_pkg[pkg_key] = set()
113-
reachable_by_pkg[pkg_key].update(vuln_ids)
114-
115-
return reachable_by_purl, reachable_by_pkg
116-
117-
def _is_alert_reachable(
118-
self,
119-
alert: Issue,
120-
reachable_by_purl: Dict[str, Set[str]],
121-
reachable_by_pkg: Dict[Tuple[str, str, str], Set[str]],
122-
) -> bool:
123-
alert_ids = self._extract_issue_vuln_ids(alert)
124-
alert_purl = self._normalize_purl(getattr(alert, "purl", ""))
125-
pkg_key = self._normalize_pkg_key(
126-
getattr(alert, "pkg_type", ""),
127-
getattr(alert, "pkg_name", ""),
128-
getattr(alert, "pkg_version", ""),
129-
)
130-
131-
if alert_ids:
132-
if alert_purl and alert_purl in reachable_by_purl and alert_ids.intersection(reachable_by_purl[alert_purl]):
133-
return True
134-
if pkg_key in reachable_by_pkg and alert_ids.intersection(reachable_by_pkg[pkg_key]):
135-
return True
136-
return False
137-
138-
if alert_purl and alert_purl in reachable_by_purl:
139-
return True
140-
return pkg_key in reachable_by_pkg
141-
142-
def _filter_sarif_reachable_alerts(self, alerts: List[Issue]) -> List[Issue]:
143-
reachability_index = self._build_reachability_index()
144-
if not reachability_index:
145-
return [a for a in alerts if getattr(a, "error", False)]
146-
147-
reachable_by_purl, reachable_by_pkg = reachability_index
148-
return [a for a in alerts if self._is_alert_reachable(a, reachable_by_purl, reachable_by_pkg)]
149-
15026
def handle_output(self, diff_report: Diff) -> None:
15127
"""Main output handler that determines output format"""
15228
# Determine which formats to output
@@ -231,7 +107,8 @@ def return_exit_code(self, diff_report: Diff) -> int:
231107

232108
def output_console_comments(self, diff_report: Diff, sbom_file_name: Optional[str] = None) -> None:
233109
"""Outputs formatted console comments"""
234-
has_new_alerts = len(diff_report.new_alerts) > 0
110+
selected_alerts = select_diff_alerts(diff_report, strict_blocking=self.config.strict_blocking)
111+
has_new_alerts = len(selected_alerts) > 0
235112
has_unchanged_alerts = (
236113
self.config.strict_blocking and
237114
hasattr(diff_report, 'unchanged_alerts') and
@@ -252,7 +129,8 @@ def output_console_comments(self, diff_report: Diff, sbom_file_name: Optional[st
252129
unchanged_blocking = sum(1 for issue in diff_report.unchanged_alerts if issue.error)
253130
unchanged_warning = sum(1 for issue in diff_report.unchanged_alerts if issue.warn)
254131

255-
console_security_comment = Messages.create_console_security_alert_table(diff_report)
132+
selected_diff = clone_diff_with_selected_alerts(diff_report, selected_alerts)
133+
console_security_comment = Messages.create_console_security_alert_table(selected_diff)
256134

257135
# Build status message
258136
self.logger.info("Security issues detected by Socket Security:")
@@ -270,7 +148,9 @@ def output_console_comments(self, diff_report: Diff, sbom_file_name: Optional[st
270148

271149
def output_console_json(self, diff_report: Diff, sbom_file_name: Optional[str] = None) -> None:
272150
"""Outputs JSON formatted results"""
273-
console_security_comment = Messages.create_security_comment_json(diff_report)
151+
selected_alerts = select_diff_alerts(diff_report, strict_blocking=self.config.strict_blocking)
152+
selected_diff = clone_diff_with_selected_alerts(diff_report, selected_alerts)
153+
console_security_comment = Messages.create_security_comment_json(selected_diff)
274154
self.save_sbom_file(diff_report, sbom_file_name)
275155
self.logger.info(json.dumps(console_security_comment))
276156

@@ -289,11 +169,12 @@ def output_console_sarif(self, diff_report: Diff, sbom_file_name: Optional[str]
289169
sarif_grouping = "instance"
290170
if sarif_reachability not in {"all", "reachable", "potentially", "reachable-or-potentially"}:
291171
sarif_reachability = "all"
292-
if getattr(self.config, "sarif_reachable_only", False) is True:
293-
sarif_reachability = "reachable"
294172
if diff_report.id != "NO_DIFF_RAN" or sarif_scope == "full":
295173
if sarif_scope == "full":
296-
components_with_alerts = self._load_components_with_alerts()
174+
components_with_alerts = load_components_with_alerts(
175+
self.config.target_path,
176+
self.config.reach_output_file,
177+
)
297178
if not components_with_alerts:
298179
self.logger.error(
299180
"Unable to generate full-scope SARIF: .socket.facts.json missing or invalid"
@@ -305,24 +186,19 @@ def output_console_sarif(self, diff_report: Diff, sbom_file_name: Optional[str]
305186
grouping=sarif_grouping,
306187
)
307188
else:
308-
if sarif_reachability == "reachable":
309-
filtered_alerts = self._filter_sarif_reachable_alerts(diff_report.new_alerts)
310-
diff_report = Diff(
311-
new_alerts=filtered_alerts,
312-
diff_url=getattr(diff_report, "diff_url", ""),
313-
new_packages=getattr(diff_report, "new_packages", []),
314-
removed_packages=getattr(diff_report, "removed_packages", []),
315-
packages=getattr(diff_report, "packages", {}),
316-
)
317-
diff_report.id = "filtered"
318-
elif sarif_reachability != "all":
319-
self.logger.warning(
320-
"Reachability filter '%s' is only supported in full SARIF scope; output is unfiltered in diff scope",
321-
sarif_reachability,
322-
)
189+
selected_alerts = select_diff_alerts(diff_report, strict_blocking=self.config.strict_blocking)
190+
filtered_alerts = filter_alerts_by_reachability(
191+
selected_alerts,
192+
sarif_reachability,
193+
self.config.target_path,
194+
self.config.reach_output_file,
195+
logger=self.logger,
196+
fallback_to_blocking_for_reachable=True,
197+
)
198+
selected_diff = clone_diff_with_selected_alerts(diff_report, filtered_alerts)
323199

324200
# Generate the SARIF structure using Messages
325-
console_security_comment = Messages.create_security_comment_sarif(diff_report)
201+
console_security_comment = Messages.create_security_comment_sarif(selected_diff)
326202
self.save_sbom_file(diff_report, sbom_file_name)
327203
# Avoid flooding logs for full-scope SARIF when writing to file.
328204
if not (sarif_scope == "full" and self.config.sarif_file):

0 commit comments

Comments
 (0)