|
| 1 | +import logging |
| 2 | +from pathlib import Path |
| 3 | +from typing import Any, Dict, List, Optional, Set, Tuple |
| 4 | + |
| 5 | +from socketsecurity.core.classes import Diff, Issue |
| 6 | +from socketsecurity.core.helper.socket_facts_loader import ( |
| 7 | + convert_to_alerts, |
| 8 | + get_components_with_vulnerabilities, |
| 9 | + load_socket_facts, |
| 10 | +) |
| 11 | +from socketsecurity.core.messages import Messages |
| 12 | + |
| 13 | + |
def select_diff_alerts(diff: Diff, strict_blocking: bool = False) -> List[Issue]:
    """Return the alerts from ``diff`` that should be rendered.

    New alerts are always returned. When ``strict_blocking`` is true the
    unchanged alerts are appended after them, so rendered output lines up
    with pass/fail policy evaluation in strict blocking mode.

    A missing or falsy ``new_alerts`` / ``unchanged_alerts`` attribute is
    treated as empty. The result is always a freshly built list.
    """
    attribute_names = ["new_alerts"]
    if strict_blocking:
        attribute_names.append("unchanged_alerts")

    chosen: List[Issue] = []
    for attribute_name in attribute_names:
        # ``or []`` covers attributes that exist but hold None.
        chosen.extend(getattr(diff, attribute_name, []) or [])
    return chosen
| 24 | + |
| 25 | + |
def clone_diff_with_selected_alerts(diff: Diff, selected_alerts: List[Issue]) -> Diff:
    """Build a new ``Diff`` whose ``new_alerts`` are ``selected_alerts``.

    ``unchanged_alerts`` and ``removed_alerts`` start out empty on the clone.
    ``diff_url``, ``new_packages``, ``removed_packages`` and ``packages`` are
    read from ``diff`` (with empty defaults when absent) and handed to the
    constructor as-is, i.e. not copied. ``id``, ``report_url`` and
    ``new_scan_id`` are then assigned on the clone, defaulting to ``""``.
    """
    constructor_args = {
        "new_alerts": selected_alerts,
        "unchanged_alerts": [],
        "removed_alerts": [],
        "diff_url": getattr(diff, "diff_url", ""),
        "new_packages": getattr(diff, "new_packages", []),
        "removed_packages": getattr(diff, "removed_packages", []),
        "packages": getattr(diff, "packages", {}),
    }
    clone = Diff(**constructor_args)

    # These fields are set after construction rather than passed as kwargs.
    for attribute_name in ("id", "report_url", "new_scan_id"):
        setattr(clone, attribute_name, getattr(diff, attribute_name, ""))
    return clone
| 41 | + |
| 42 | + |
def load_components_with_alerts(
    target_path: Optional[str],
    reach_output_file: Optional[str],
) -> Optional[List[Dict[str, Any]]]:
    """Load the facts file and return its vulnerable components as alerts.

    The file name falls back to ``.socket.facts.json`` and the directory to
    ``.`` when the corresponding argument is falsy. The two are joined with
    ``pathlib`` and the resulting path string is given to
    ``load_socket_facts``.

    Returns ``None`` when ``load_socket_facts`` yields a falsy value;
    otherwise the output of ``convert_to_alerts`` applied to
    ``get_components_with_vulnerabilities``.
    """
    base_dir = Path(target_path) if target_path else Path(".")
    file_name = reach_output_file if reach_output_file else ".socket.facts.json"

    facts = load_socket_facts(str(base_dir / file_name))
    if facts:
        return convert_to_alerts(get_components_with_vulnerabilities(facts))
    return None
| 55 | + |
| 56 | + |
| 57 | +def _normalize_purl(purl: str) -> str: |
| 58 | + if not purl: |
| 59 | + return "" |
| 60 | + normalized = purl.strip().lower().replace("%40", "@") |
| 61 | + if normalized.startswith("pkg:"): |
| 62 | + normalized = normalized[4:] |
| 63 | + return normalized |
| 64 | + |
| 65 | + |
| 66 | +def _normalize_vuln_id(vuln_id: str) -> str: |
| 67 | + if not vuln_id: |
| 68 | + return "" |
| 69 | + return vuln_id.strip().upper() |
| 70 | + |
| 71 | + |
| 72 | +def _normalize_pkg_key(pkg_type: str, pkg_name: str, pkg_version: str) -> Tuple[str, str, str]: |
| 73 | + return ( |
| 74 | + (pkg_type or "").strip().lower(), |
| 75 | + (pkg_name or "").strip().lower(), |
| 76 | + (pkg_version or "").strip().lower(), |
| 77 | + ) |
| 78 | + |
| 79 | + |
| 80 | +def _extract_issue_vuln_ids(issue: Issue) -> Set[str]: |
| 81 | + ids: Set[str] = set() |
| 82 | + props = getattr(issue, "props", None) or {} |
| 83 | + for key in ("ghsaId", "ghsa_id", "cveId", "cve_id"): |
| 84 | + value = props.get(key) |
| 85 | + if isinstance(value, str) and value.strip(): |
| 86 | + ids.add(_normalize_vuln_id(value)) |
| 87 | + return ids |
| 88 | + |
| 89 | + |
def _is_potentially_reachable(reachability: str, undeterminable: bool = False) -> bool:
    """Tell whether a reachability value counts as "potentially reachable".

    ``reachability`` is first passed through
    ``Messages._normalize_reachability``. The result qualifies when it is one
    of ``unknown``, ``error``, ``maybe_reachable`` or
    ``potentially_reachable``; otherwise ``undeterminable`` is returned.
    """
    if Messages._normalize_reachability(reachability) in {
        "unknown",
        "error",
        "maybe_reachable",
        "potentially_reachable",
    }:
        return True
    return undeterminable
| 94 | + |
| 95 | + |
| 96 | +def _matches_selector(states: Set[str], selector: str) -> bool: |
| 97 | + selected = (selector or "all").strip().lower() |
| 98 | + if selected == "all": |
| 99 | + return True |
| 100 | + if not states: |
| 101 | + return False |
| 102 | + if selected == "reachable": |
| 103 | + return "reachable" in states |
| 104 | + if selected == "potentially": |
| 105 | + return any(_is_potentially_reachable(state) for state in states) |
| 106 | + if selected == "reachable-or-potentially": |
| 107 | + return "reachable" in states or any(_is_potentially_reachable(state) for state in states) |
| 108 | + return True |
| 109 | + |
| 110 | + |
| 111 | +def _build_reachability_index( |
| 112 | + components_with_alerts: Optional[List[Dict[str, Any]]], |
| 113 | +) -> Optional[Tuple[Dict[str, Dict[str, Set[str]]], Dict[Tuple[str, str, str], Dict[str, Set[str]]]]]: |
| 114 | + if not components_with_alerts: |
| 115 | + return None |
| 116 | + |
| 117 | + by_purl: Dict[str, Dict[str, Set[str]]] = {} |
| 118 | + by_pkg: Dict[Tuple[str, str, str], Dict[str, Set[str]]] = {} |
| 119 | + |
| 120 | + for component in components_with_alerts: |
| 121 | + component_alerts = component.get("alerts", []) |
| 122 | + pkg_type = component.get("type", "") |
| 123 | + pkg_version = component.get("version", "") |
| 124 | + namespace = (component.get("namespace") or "").strip() |
| 125 | + name = (component.get("name") or component.get("id") or "").strip() |
| 126 | + |
| 127 | + pkg_names: Set[str] = {name} |
| 128 | + if namespace: |
| 129 | + pkg_names.add(f"{namespace}/{name}") |
| 130 | + |
| 131 | + for alert in component_alerts: |
| 132 | + props = alert.get("props", {}) or {} |
| 133 | + reachability = Messages._normalize_reachability(props.get("reachability", "unknown")) |
| 134 | + vuln_ids = { |
| 135 | + _normalize_vuln_id(props.get("ghsaId", "")), |
| 136 | + _normalize_vuln_id(props.get("cveId", "")), |
| 137 | + } |
| 138 | + vuln_ids = {v for v in vuln_ids if v} |
| 139 | + purl = _normalize_purl(props.get("purl", "")) |
| 140 | + |
| 141 | + def _add(container: Dict[Any, Dict[str, Set[str]]], key: Any) -> None: |
| 142 | + if key not in container: |
| 143 | + container[key] = {} |
| 144 | + vuln_key = next(iter(vuln_ids)) if len(vuln_ids) == 1 else "*" |
| 145 | + if vuln_key not in container[key]: |
| 146 | + container[key][vuln_key] = set() |
| 147 | + container[key][vuln_key].add(reachability) |
| 148 | + if vuln_ids and vuln_key == "*": |
| 149 | + for vuln_id in vuln_ids: |
| 150 | + if vuln_id not in container[key]: |
| 151 | + container[key][vuln_id] = set() |
| 152 | + container[key][vuln_id].add(reachability) |
| 153 | + if not vuln_ids: |
| 154 | + if "*" not in container[key]: |
| 155 | + container[key]["*"] = set() |
| 156 | + container[key]["*"].add(reachability) |
| 157 | + |
| 158 | + if purl: |
| 159 | + _add(by_purl, purl) |
| 160 | + |
| 161 | + for pkg_name in pkg_names: |
| 162 | + pkg_key = _normalize_pkg_key(pkg_type, pkg_name, pkg_version) |
| 163 | + _add(by_pkg, pkg_key) |
| 164 | + |
| 165 | + return by_purl, by_pkg |
| 166 | + |
| 167 | + |
def _alert_reachability_states(
    alert: Issue,
    by_purl: Dict[str, Dict[str, Set[str]]],
    by_pkg: Dict[Tuple[str, str, str], Dict[str, Set[str]]],
) -> Set[str]:
    """Gather every reachability state the indexes hold for ``alert``.

    The alert is looked up by its normalized purl (when it has one) and by
    its normalized ``(pkg_type, pkg_name, pkg_version)`` key; the results of
    both lookups are unioned.

    Within one index entry the ``"*"`` bucket is always included. If the
    alert carries vulnerability IDs, only the buckets for those IDs are
    added; if it carries none, every bucket of the entry is added.
    """
    wanted_ids = _extract_issue_vuln_ids(alert)
    purl = _normalize_purl(getattr(alert, "purl", ""))
    pkg_key = _normalize_pkg_key(
        *(getattr(alert, field, "") for field in ("pkg_type", "pkg_name", "pkg_version"))
    )

    def _lookup(index: Dict[Any, Dict[str, Set[str]]], key: Any) -> Set[str]:
        per_vuln = index.get(key, {})
        if not per_vuln:
            return set()
        if wanted_ids:
            relevant = [per_vuln[vuln_id] for vuln_id in wanted_ids if vuln_id in per_vuln]
        else:
            relevant = list(per_vuln.values())
        # The wildcard bucket applies regardless of which IDs matched.
        return set().union(per_vuln.get("*", ()), *relevant)

    matched: Set[str] = _lookup(by_purl, purl) if purl else set()
    matched |= _lookup(by_pkg, pkg_key)
    return matched
| 204 | + |
| 205 | + |
def filter_alerts_by_reachability(
    alerts: List[Issue],
    selector: str,
    target_path: Optional[str],
    reach_output_file: Optional[str],
    logger: Optional[logging.Logger] = None,
    fallback_to_blocking_for_reachable: bool = True,
) -> List[Issue]:
    """
    Keep only the alerts whose reachability satisfies ``selector``.

    The selector is stripped and lower-cased; a falsy selector or ``all``
    returns a copy of ``alerts`` without reading any facts data.

    For other selectors the facts file is loaded via
    ``load_components_with_alerts`` and indexed. If no index can be built, a
    warning is logged (when ``logger`` is given) and:

    * for ``reachable`` with ``fallback_to_blocking_for_reachable`` set, the
      alerts with a truthy ``error`` attribute are returned;
    * otherwise an empty list is returned.
    """
    mode = (selector or "all").strip().lower()
    if mode == "all":
        return list(alerts)

    index = _build_reachability_index(
        load_components_with_alerts(target_path, reach_output_file)
    )
    if index:
        by_purl, by_pkg = index
        return [
            candidate
            for candidate in alerts
            if _matches_selector(_alert_reachability_states(candidate, by_purl, by_pkg), mode)
        ]

    # No usable facts data from here on.
    if logger:
        logger.warning("Unable to load reachability facts for selector '%s'", mode)
    if mode == "reachable" and fallback_to_blocking_for_reachable:
        return [candidate for candidate in alerts if getattr(candidate, "error", False)]
    return []
0 commit comments