Skip to content

Commit 3cf14b6

Browse files
committed
Add shared selector/filter module
Signed-off-by: lelia <lelia@socket.dev>
1 parent 98e8ee5 commit 3cf14b6

File tree

1 file changed

+239
-0
lines changed

1 file changed

+239
-0
lines changed
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import logging
2+
from pathlib import Path
3+
from typing import Any, Dict, List, Optional, Set, Tuple
4+
5+
from socketsecurity.core.classes import Diff, Issue
6+
from socketsecurity.core.helper.socket_facts_loader import (
7+
convert_to_alerts,
8+
get_components_with_vulnerabilities,
9+
load_socket_facts,
10+
)
11+
from socketsecurity.core.messages import Messages
12+
13+
14+
def select_diff_alerts(diff: Diff, strict_blocking: bool = False) -> List[Issue]:
    """Pick the alerts of ``diff`` that should appear in rendered output.

    The diff's ``new_alerts`` are always selected. With ``strict_blocking``
    enabled, its ``unchanged_alerts`` are appended after them so that the
    rendered output lines up with pass/fail policy evaluation.

    A missing or falsy attribute on ``diff`` is treated as an empty list.
    The result is always a freshly built list, so callers may mutate it
    without touching the lists stored on ``diff``.
    """
    chosen: List[Issue] = []
    chosen.extend(getattr(diff, "new_alerts", None) or [])
    if not strict_blocking:
        return chosen
    chosen.extend(getattr(diff, "unchanged_alerts", None) or [])
    return chosen
24+
25+
26+
def clone_diff_with_selected_alerts(diff: Diff, selected_alerts: List[Issue]) -> Diff:
    """Build a new ``Diff`` whose ``new_alerts`` are ``selected_alerts``.

    The clone starts with empty ``unchanged_alerts`` and ``removed_alerts``.
    ``diff_url``, ``new_packages``, ``removed_packages`` and ``packages`` are
    carried over from ``diff`` (by reference, not copied), as are the
    ``id``, ``report_url`` and ``new_scan_id`` attributes. Any attribute
    missing on ``diff`` falls back to an empty string, list or dict as
    appropriate.
    """
    constructor_kwargs = {
        "new_alerts": selected_alerts,
        "unchanged_alerts": [],
        "removed_alerts": [],
        "diff_url": getattr(diff, "diff_url", ""),
        "new_packages": getattr(diff, "new_packages", []),
        "removed_packages": getattr(diff, "removed_packages", []),
        "packages": getattr(diff, "packages", {}),
    }
    clone = Diff(**constructor_kwargs)

    # These identifiers are assigned after construction rather than passed
    # to the constructor, mirroring how the source diff exposes them.
    for attr_name in ("id", "report_url", "new_scan_id"):
        setattr(clone, attr_name, getattr(diff, attr_name, ""))
    return clone
41+
42+
43+
def load_components_with_alerts(
    target_path: Optional[str],
    reach_output_file: Optional[str],
) -> Optional[List[Dict[str, Any]]]:
    """Load the socket facts file and convert its components to alert dicts.

    The facts file is looked up at ``target_path / reach_output_file``;
    ``target_path`` defaults to the current directory and
    ``reach_output_file`` to ``.socket.facts.json`` when either is falsy.

    Returns ``None`` when ``load_socket_facts`` yields a falsy value;
    otherwise returns whatever ``convert_to_alerts`` produces for the
    components selected by ``get_components_with_vulnerabilities``.
    """
    base_dir = Path(target_path if target_path else ".")
    file_name = reach_output_file if reach_output_file else ".socket.facts.json"

    facts = load_socket_facts(str(base_dir / file_name))
    if facts:
        return convert_to_alerts(get_components_with_vulnerabilities(facts))
    return None
55+
56+
57+
def _normalize_purl(purl: str) -> str:
58+
if not purl:
59+
return ""
60+
normalized = purl.strip().lower().replace("%40", "@")
61+
if normalized.startswith("pkg:"):
62+
normalized = normalized[4:]
63+
return normalized
64+
65+
66+
def _normalize_vuln_id(vuln_id: str) -> str:
67+
if not vuln_id:
68+
return ""
69+
return vuln_id.strip().upper()
70+
71+
72+
def _normalize_pkg_key(pkg_type: str, pkg_name: str, pkg_version: str) -> Tuple[str, str, str]:
73+
return (
74+
(pkg_type or "").strip().lower(),
75+
(pkg_name or "").strip().lower(),
76+
(pkg_version or "").strip().lower(),
77+
)
78+
79+
80+
def _extract_issue_vuln_ids(issue: Issue) -> Set[str]:
    """Collect the normalized vulnerability identifiers found on ``issue``.

    Reads the ``ghsaId``, ``ghsa_id``, ``cveId`` and ``cve_id`` entries of
    ``issue.props`` (a missing or falsy ``props`` counts as empty). Only
    non-blank string values are kept, each passed through
    ``_normalize_vuln_id``.
    """
    props = getattr(issue, "props", None) or {}
    candidates = (props.get(key) for key in ("ghsaId", "ghsa_id", "cveId", "cve_id"))
    return {
        _normalize_vuln_id(candidate)
        for candidate in candidates
        if isinstance(candidate, str) and candidate.strip()
    }
88+
89+
90+
def _is_potentially_reachable(reachability: str, undeterminable: bool = False) -> bool:
    """Tell whether a reachability value counts as "potentially reachable".

    ``reachability`` is first normalized with
    ``Messages._normalize_reachability``. The result is true when the
    normalized value is one of ``unknown``, ``error``, ``maybe_reachable`` or
    ``potentially_reachable``; otherwise ``undeterminable`` is returned.
    """
    state = Messages._normalize_reachability(reachability)
    if state in frozenset({"unknown", "error", "maybe_reachable", "potentially_reachable"}):
        return True
    return undeterminable
94+
95+
96+
def _matches_selector(states: Set[str], selector: str) -> bool:
97+
selected = (selector or "all").strip().lower()
98+
if selected == "all":
99+
return True
100+
if not states:
101+
return False
102+
if selected == "reachable":
103+
return "reachable" in states
104+
if selected == "potentially":
105+
return any(_is_potentially_reachable(state) for state in states)
106+
if selected == "reachable-or-potentially":
107+
return "reachable" in states or any(_is_potentially_reachable(state) for state in states)
108+
return True
109+
110+
111+
def _build_reachability_index(
    components_with_alerts: Optional[List[Dict[str, Any]]],
) -> Optional[Tuple[Dict[str, Dict[str, Set[str]]], Dict[Tuple[str, str, str], Dict[str, Set[str]]]]]:
    """Index alert reachability by purl and by package key.

    Args:
        components_with_alerts: Component dicts, each optionally carrying
            ``alerts``, ``type``, ``version``, ``namespace`` and ``name``
            (or ``id``) entries. Each alert dict may carry a ``props`` dict
            with ``reachability``, ``ghsaId``, ``cveId`` and ``purl``.

    Returns:
        ``None`` when ``components_with_alerts`` is falsy. Otherwise a
        ``(by_purl, by_pkg)`` pair. ``by_purl`` is keyed by normalized purl
        and ``by_pkg`` by normalized ``(type, name, version)``; each value
        maps a vulnerability id (or the wildcard ``"*"``) to the set of
        normalized reachability states seen for it.

    Bucket rules for one alert:
        * exactly one vulnerability id -> recorded under that id only;
        * no id, or more than one id -> recorded under ``"*"`` and under
          every id that is present.
    """
    if not components_with_alerts:
        return None

    by_purl: Dict[str, Dict[str, Set[str]]] = {}
    by_pkg: Dict[Tuple[str, str, str], Dict[str, Set[str]]] = {}

    def _record(
        container: Dict[Any, Dict[str, Set[str]]],
        key: Any,
        buckets: List[str],
        reachability: str,
    ) -> None:
        # Register one reachability state under every bucket of ``key``.
        per_vuln = container.setdefault(key, {})
        for bucket in buckets:
            per_vuln.setdefault(bucket, set()).add(reachability)

    for component in components_with_alerts:
        # ``alerts`` may be present but set to None; treat that as "no alerts"
        # instead of failing when iterating.
        component_alerts = component.get("alerts") or []
        pkg_type = component.get("type", "")
        pkg_version = component.get("version", "")
        namespace = (component.get("namespace") or "").strip()
        name = (component.get("name") or component.get("id") or "").strip()

        # A namespaced package is indexed under both its bare name and its
        # "namespace/name" form so either spelling on an alert can match.
        pkg_names: Set[str] = {name}
        if namespace:
            pkg_names.add(f"{namespace}/{name}")

        # Package keys depend only on the component, so compute them once
        # rather than once per alert.
        pkg_keys = [
            _normalize_pkg_key(pkg_type, pkg_name, pkg_version)
            for pkg_name in pkg_names
        ]

        for alert in component_alerts:
            props = alert.get("props", {}) or {}
            reachability = Messages._normalize_reachability(props.get("reachability", "unknown"))
            vuln_ids = {
                _normalize_vuln_id(props.get("ghsaId", "")),
                _normalize_vuln_id(props.get("cveId", "")),
            }
            vuln_ids.discard("")
            purl = _normalize_purl(props.get("purl", ""))

            if len(vuln_ids) == 1:
                buckets = list(vuln_ids)
            else:
                buckets = ["*", *vuln_ids]

            if purl:
                _record(by_purl, purl, buckets, reachability)
            for pkg_key in pkg_keys:
                _record(by_pkg, pkg_key, buckets, reachability)

    return by_purl, by_pkg
166+
167+
168+
def _alert_reachability_states(
    alert: Issue,
    by_purl: Dict[str, Dict[str, Set[str]]],
    by_pkg: Dict[Tuple[str, str, str], Dict[str, Set[str]]],
) -> Set[str]:
    """Gather every reachability state the indexes hold for ``alert``.

    The alert is looked up by its normalized ``purl`` (when it has one) and
    by its normalized ``(pkg_type, pkg_name, pkg_version)`` key; the states
    from both lookups are unioned.

    Within one index entry:
        * if the alert carries vulnerability ids, only the ``"*"`` bucket and
          the buckets named by those ids contribute;
        * if it carries none, every bucket contributes.
    """
    vuln_ids = _extract_issue_vuln_ids(alert)
    purl_key = _normalize_purl(getattr(alert, "purl", ""))
    package_key = _normalize_pkg_key(
        getattr(alert, "pkg_type", ""),
        getattr(alert, "pkg_name", ""),
        getattr(alert, "pkg_version", ""),
    )

    def _lookup(index: Dict[Any, Dict[str, Set[str]]], key: Any) -> Set[str]:
        per_vuln = index.get(key, {})
        hits: Set[str] = set()
        if not per_vuln:
            return hits

        if not vuln_ids:
            # No id to narrow by: take every bucket (the wildcard included).
            for bucket_states in per_vuln.values():
                hits |= bucket_states
            return hits

        for bucket in ("*", *vuln_ids):
            if bucket in per_vuln:
                hits |= per_vuln[bucket]
        return hits

    lookups = []
    if purl_key:
        lookups.append((by_purl, purl_key))
    lookups.append((by_pkg, package_key))

    states: Set[str] = set()
    for index, key in lookups:
        states |= _lookup(index, key)
    return states
204+
205+
206+
def filter_alerts_by_reachability(
    alerts: List[Issue],
    selector: str,
    target_path: Optional[str],
    reach_output_file: Optional[str],
    logger: Optional[logging.Logger] = None,
    fallback_to_blocking_for_reachable: bool = True,
) -> List[Issue]:
    """
    Keep only the alerts whose reachability satisfies ``selector``.

    Reachability comes from the facts file located through ``target_path``
    and ``reach_output_file`` (see ``load_components_with_alerts``).

    * A falsy or ``all`` selector returns a copy of ``alerts`` without
      reading any facts.
    * When no reachability index can be built, a warning is logged (if a
      ``logger`` is given). For the ``reachable`` selector with
      ``fallback_to_blocking_for_reachable`` enabled, alerts with a truthy
      ``error`` attribute are returned for backward compatibility; in every
      other case the result is empty.
    * Otherwise each alert is kept when ``_matches_selector`` accepts the
      states found for it by ``_alert_reachability_states``.
    """
    mode = (selector or "all").strip().lower()
    if mode == "all":
        return list(alerts)

    index = _build_reachability_index(
        load_components_with_alerts(target_path, reach_output_file)
    )

    if index:
        by_purl, by_pkg = index
        return [
            candidate
            for candidate in alerts
            if _matches_selector(
                _alert_reachability_states(candidate, by_purl, by_pkg), mode
            )
        ]

    # No usable facts: warn, then apply the legacy fallback if it applies.
    if logger:
        logger.warning("Unable to load reachability facts for selector '%s'", mode)
    if mode == "reachable" and fallback_to_blocking_for_reachable:
        return [candidate for candidate in alerts if getattr(candidate, "error", False)]
    return []

0 commit comments

Comments
 (0)