11import json
22import logging
33from pathlib import Path
4- from typing import Any , Dict , Optional , List , Set , Tuple
4+ from typing import Any , Dict , Optional
55from .core .messages import Messages
66from .core .classes import Diff , Issue
77from .config import CliConfig
88from socketsecurity .plugins .manager import PluginManager
9- from socketsecurity .core .helper .socket_facts_loader import (
10- load_socket_facts ,
11- get_components_with_vulnerabilities ,
12- convert_to_alerts ,
9+ from socketsecurity .core .alert_selection import (
10+ clone_diff_with_selected_alerts ,
11+ filter_alerts_by_reachability ,
12+ load_components_with_alerts ,
13+ select_diff_alerts ,
1314)
1415from socketdev import socketdev
1516
@@ -22,131 +23,6 @@ def __init__(self, config: CliConfig, sdk: socketdev):
2223 self .config = config
2324 self .logger = logging .getLogger ("socketcli" )
2425
25- @staticmethod
26- def _normalize_purl (purl : str ) -> str :
27- if not purl :
28- return ""
29-
30- normalized = purl .strip ().lower ().replace ("%40" , "@" )
31- if normalized .startswith ("pkg:" ):
32- normalized = normalized [4 :]
33- return normalized
34-
35- @staticmethod
36- def _normalize_vuln_id (vuln_id : str ) -> str :
37- if not vuln_id :
38- return ""
39- return vuln_id .strip ().upper ()
40-
41- @staticmethod
42- def _normalize_pkg_key (pkg_type : str , pkg_name : str , pkg_version : str ) -> Tuple [str , str , str ]:
43- return (
44- (pkg_type or "" ).strip ().lower (),
45- (pkg_name or "" ).strip ().lower (),
46- (pkg_version or "" ).strip ().lower (),
47- )
48-
49- @staticmethod
50- def _extract_issue_vuln_ids (issue : Issue ) -> Set [str ]:
51- ids : Set [str ] = set ()
52- props = getattr (issue , "props" , None ) or {}
53- for key in ("ghsaId" , "ghsa_id" , "cveId" , "cve_id" ):
54- value = props .get (key )
55- if isinstance (value , str ) and value .strip ():
56- ids .add (OutputHandler ._normalize_vuln_id (value ))
57- return ids
58-
59- def _load_components_with_alerts (self ) -> Optional [List [Dict [str , Any ]]]:
60- facts_file = self .config .reach_output_file or ".socket.facts.json"
61- facts_file_path = str (Path (self .config .target_path or "." ) / facts_file )
62- facts_data = load_socket_facts (facts_file_path )
63- if not facts_data :
64- return None
65-
66- components = get_components_with_vulnerabilities (facts_data )
67- return convert_to_alerts (components )
68-
69- def _build_reachability_index (self ) -> Optional [Tuple [Dict [str , Set [str ]], Dict [Tuple [str , str , str ], Set [str ]]]]:
70- components_with_alerts = self ._load_components_with_alerts ()
71- if not components_with_alerts :
72- self .logger .warning (
73- "Unable to load reachability facts; falling back to blocking-based SARIF filter"
74- )
75- return None
76-
77- reachable_by_purl : Dict [str , Set [str ]] = {}
78- reachable_by_pkg : Dict [Tuple [str , str , str ], Set [str ]] = {}
79-
80- for component in components_with_alerts :
81- purl = self ._normalize_purl (component .get ("alerts" , [{}])[0 ].get ("props" , {}).get ("purl" , "" ))
82- pkg_type = component .get ("type" , "" )
83- pkg_version = component .get ("version" , "" )
84- namespace = (component .get ("namespace" ) or "" ).strip ()
85- name = (component .get ("name" ) or component .get ("id" ) or "" ).strip ()
86-
87- pkg_names : Set [str ] = {name }
88- if namespace :
89- pkg_names .add (f"{ namespace } /{ name } " )
90-
91- for alert in component .get ("alerts" , []):
92- props = alert .get ("props" , {}) or {}
93- if props .get ("reachability" ) != "reachable" :
94- continue
95-
96- vuln_ids = {
97- self ._normalize_vuln_id (props .get ("ghsaId" , "" )),
98- self ._normalize_vuln_id (props .get ("cveId" , "" )),
99- }
100- vuln_ids = {v for v in vuln_ids if v }
101- if not vuln_ids :
102- continue
103-
104- if purl :
105- if purl not in reachable_by_purl :
106- reachable_by_purl [purl ] = set ()
107- reachable_by_purl [purl ].update (vuln_ids )
108-
109- for pkg_name in pkg_names :
110- pkg_key = self ._normalize_pkg_key (pkg_type , pkg_name , pkg_version )
111- if pkg_key not in reachable_by_pkg :
112- reachable_by_pkg [pkg_key ] = set ()
113- reachable_by_pkg [pkg_key ].update (vuln_ids )
114-
115- return reachable_by_purl , reachable_by_pkg
116-
117- def _is_alert_reachable (
118- self ,
119- alert : Issue ,
120- reachable_by_purl : Dict [str , Set [str ]],
121- reachable_by_pkg : Dict [Tuple [str , str , str ], Set [str ]],
122- ) -> bool :
123- alert_ids = self ._extract_issue_vuln_ids (alert )
124- alert_purl = self ._normalize_purl (getattr (alert , "purl" , "" ))
125- pkg_key = self ._normalize_pkg_key (
126- getattr (alert , "pkg_type" , "" ),
127- getattr (alert , "pkg_name" , "" ),
128- getattr (alert , "pkg_version" , "" ),
129- )
130-
131- if alert_ids :
132- if alert_purl and alert_purl in reachable_by_purl and alert_ids .intersection (reachable_by_purl [alert_purl ]):
133- return True
134- if pkg_key in reachable_by_pkg and alert_ids .intersection (reachable_by_pkg [pkg_key ]):
135- return True
136- return False
137-
138- if alert_purl and alert_purl in reachable_by_purl :
139- return True
140- return pkg_key in reachable_by_pkg
141-
142- def _filter_sarif_reachable_alerts (self , alerts : List [Issue ]) -> List [Issue ]:
143- reachability_index = self ._build_reachability_index ()
144- if not reachability_index :
145- return [a for a in alerts if getattr (a , "error" , False )]
146-
147- reachable_by_purl , reachable_by_pkg = reachability_index
148- return [a for a in alerts if self ._is_alert_reachable (a , reachable_by_purl , reachable_by_pkg )]
149-
15026 def handle_output (self , diff_report : Diff ) -> None :
15127 """Main output handler that determines output format"""
15228 # Determine which formats to output
@@ -231,7 +107,8 @@ def return_exit_code(self, diff_report: Diff) -> int:
231107
232108 def output_console_comments (self , diff_report : Diff , sbom_file_name : Optional [str ] = None ) -> None :
233109 """Outputs formatted console comments"""
234- has_new_alerts = len (diff_report .new_alerts ) > 0
110+ selected_alerts = select_diff_alerts (diff_report , strict_blocking = self .config .strict_blocking )
111+ has_new_alerts = len (selected_alerts ) > 0
235112 has_unchanged_alerts = (
236113 self .config .strict_blocking and
237114 hasattr (diff_report , 'unchanged_alerts' ) and
@@ -252,7 +129,8 @@ def output_console_comments(self, diff_report: Diff, sbom_file_name: Optional[st
252129 unchanged_blocking = sum (1 for issue in diff_report .unchanged_alerts if issue .error )
253130 unchanged_warning = sum (1 for issue in diff_report .unchanged_alerts if issue .warn )
254131
255- console_security_comment = Messages .create_console_security_alert_table (diff_report )
132+ selected_diff = clone_diff_with_selected_alerts (diff_report , selected_alerts )
133+ console_security_comment = Messages .create_console_security_alert_table (selected_diff )
256134
257135 # Build status message
258136 self .logger .info ("Security issues detected by Socket Security:" )
@@ -270,7 +148,9 @@ def output_console_comments(self, diff_report: Diff, sbom_file_name: Optional[st
270148
271149 def output_console_json (self , diff_report : Diff , sbom_file_name : Optional [str ] = None ) -> None :
272150 """Outputs JSON formatted results"""
273- console_security_comment = Messages .create_security_comment_json (diff_report )
151+ selected_alerts = select_diff_alerts (diff_report , strict_blocking = self .config .strict_blocking )
152+ selected_diff = clone_diff_with_selected_alerts (diff_report , selected_alerts )
153+ console_security_comment = Messages .create_security_comment_json (selected_diff )
274154 self .save_sbom_file (diff_report , sbom_file_name )
275155 self .logger .info (json .dumps (console_security_comment ))
276156
@@ -289,11 +169,12 @@ def output_console_sarif(self, diff_report: Diff, sbom_file_name: Optional[str]
289169 sarif_grouping = "instance"
290170 if sarif_reachability not in {"all" , "reachable" , "potentially" , "reachable-or-potentially" }:
291171 sarif_reachability = "all"
292- if getattr (self .config , "sarif_reachable_only" , False ) is True :
293- sarif_reachability = "reachable"
294172 if diff_report .id != "NO_DIFF_RAN" or sarif_scope == "full" :
295173 if sarif_scope == "full" :
296- components_with_alerts = self ._load_components_with_alerts ()
174+ components_with_alerts = load_components_with_alerts (
175+ self .config .target_path ,
176+ self .config .reach_output_file ,
177+ )
297178 if not components_with_alerts :
298179 self .logger .error (
299180 "Unable to generate full-scope SARIF: .socket.facts.json missing or invalid"
@@ -305,24 +186,19 @@ def output_console_sarif(self, diff_report: Diff, sbom_file_name: Optional[str]
305186 grouping = sarif_grouping ,
306187 )
307188 else :
308- if sarif_reachability == "reachable" :
309- filtered_alerts = self ._filter_sarif_reachable_alerts (diff_report .new_alerts )
310- diff_report = Diff (
311- new_alerts = filtered_alerts ,
312- diff_url = getattr (diff_report , "diff_url" , "" ),
313- new_packages = getattr (diff_report , "new_packages" , []),
314- removed_packages = getattr (diff_report , "removed_packages" , []),
315- packages = getattr (diff_report , "packages" , {}),
316- )
317- diff_report .id = "filtered"
318- elif sarif_reachability != "all" :
319- self .logger .warning (
320- "Reachability filter '%s' is only supported in full SARIF scope; output is unfiltered in diff scope" ,
321- sarif_reachability ,
322- )
189+ selected_alerts = select_diff_alerts (diff_report , strict_blocking = self .config .strict_blocking )
190+ filtered_alerts = filter_alerts_by_reachability (
191+ selected_alerts ,
192+ sarif_reachability ,
193+ self .config .target_path ,
194+ self .config .reach_output_file ,
195+ logger = self .logger ,
196+ fallback_to_blocking_for_reachable = True ,
197+ )
198+ selected_diff = clone_diff_with_selected_alerts (diff_report , filtered_alerts )
323199
324200 # Generate the SARIF structure using Messages
325- console_security_comment = Messages .create_security_comment_sarif (diff_report )
201+ console_security_comment = Messages .create_security_comment_sarif (selected_diff )
326202 self .save_sbom_file (diff_report , sbom_file_name )
327203 # Avoid flooding logs for full-scope SARIF when writing to file.
328204 if not (sarif_scope == "full" and self .config .sarif_file ):
0 commit comments