diff --git a/dojo/tools/risk_recon/api.py b/dojo/tools/risk_recon/api.py index d9691d41aa9..5bf26d83115 100644 --- a/dojo/tools/risk_recon/api.py +++ b/dojo/tools/risk_recon/api.py @@ -1,6 +1,7 @@ -import requests from django.conf import settings +from dojo.utils_ssrf import SSRFError, make_ssrf_safe_session, validate_url_for_ssrf + class RiskReconAPI: def __init__(self, api_key, endpoint, data): @@ -26,7 +27,14 @@ def __init__(self, api_key, endpoint, data): raise Exception(msg) if self.url.endswith("/"): self.url = endpoint[:-1] - self.session = requests.Session() + + try: + validate_url_for_ssrf(self.url) + except SSRFError as exc: + msg = f"Invalid Risk Recon API url: {exc}" + raise Exception(msg) from exc + + self.session = make_ssrf_safe_session() self.map_toes() self.get_findings() @@ -54,7 +62,7 @@ def map_toes(self): filters = comps.get(name) self.toe_map[toe_id] = filters or self.data else: - msg = f"Unable to query Target of Evaluations due to {response.status_code} - {response.content}" + msg = f"Unable to query Target of Evaluations due to {response.status_code}" raise Exception(msg) # TODO: when implementing ruff BLE001, please fix also TODO in unittests/test_risk_recon.py def filter_finding(self, finding): @@ -86,5 +94,5 @@ def get_findings(self): if not self.filter_finding(finding): self.findings.append(finding) else: - msg = f"Unable to collect findings from toe: {toe} due to {response.status_code} - {response.content}" + msg = f"Unable to collect findings from toe: {toe} due to {response.status_code}" raise Exception(msg) diff --git a/dojo/utils_ssrf.py b/dojo/utils_ssrf.py new file mode 100644 index 00000000000..42334e0173b --- /dev/null +++ b/dojo/utils_ssrf.py @@ -0,0 +1,172 @@ +""" +SSRF (Server-Side Request Forgery) protection utilities. + +Provides a requests.Session that validates outbound URLs against private/reserved +IP ranges at socket-creation time, closing the DNS rebinding (TOCTOU) window that +exists when validation is performed only as a pre-flight step. + +Usage: + from dojo.utils_ssrf import make_ssrf_safe_session, validate_url_for_ssrf, SSRFError + + # Pre-flight validation (raises SSRFError with a human-readable message): + validate_url_for_ssrf(url) + + # Safe session (validates at socket-creation time on every request): + session = make_ssrf_safe_session() + response = session.get(url) +""" + +import ipaddress +import socket +from urllib.parse import urlparse + +import requests +import urllib3.connection +import urllib3.connectionpool +from requests.adapters import DEFAULT_POOLBLOCK, DEFAULT_POOLSIZE, HTTPAdapter + + +class SSRFError(ValueError): + + """Raised when a URL is determined to be unsafe for server-side requests.""" + + +_ALLOWED_SCHEMES = frozenset({"http", "https"}) + + +def _check_ip(ip_str: str) -> None: + """Raise SSRFError if the IP address is not globally routable.""" + try: + ip = ipaddress.ip_address(ip_str) + except ValueError as exc: + msg = f"Cannot parse IP address: {ip_str!r}" + raise SSRFError(msg) from exc + + # ip.is_global is False for loopback, link-local (169.254.x.x), RFC 1918, + # reserved, multicast, and unspecified addresses. + if not ip.is_global: + msg = ( + f"Blocked: URL resolved to non-public address {ip}. " + "Requests to private, loopback, link-local, or reserved " + "addresses are not permitted." + ) + raise SSRFError(msg) + + +def _resolve_and_check(hostname: str, port: int) -> None: + """Resolve hostname and verify every returned address is publicly routable.""" + try: + addr_infos = socket.getaddrinfo( + hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM, + ) + except socket.gaierror as exc: + msg = f"Unable to resolve hostname {hostname!r}: {exc}" + raise SSRFError(msg) from exc + + if not addr_infos: + msg = f"No addresses returned for hostname {hostname!r}" + raise SSRFError(msg) + + for _family, _type, _proto, _canon, sockaddr in addr_infos: + _check_ip(sockaddr[0]) + + +def validate_url_for_ssrf(url: str) -> None: + """ + Pre-flight SSRF validation for a URL. + + Checks: + - Scheme is http or https (blocks file://, gopher://, etc.) + - Every resolved IP address is globally routable (blocks RFC 1918, + loopback 127.x, link-local 169.254.x.x, and other reserved ranges) + + Raises SSRFError with a descriptive message if the URL is unsafe. + This is a best-effort pre-flight check; use make_ssrf_safe_session() for + socket-level enforcement that also mitigates DNS rebinding. + """ + try: + parsed = urlparse(url) + except Exception as exc: + msg = f"Malformed URL: {url!r}" + raise SSRFError(msg) from exc + + if parsed.scheme not in _ALLOWED_SCHEMES: + msg = ( + f"URL scheme {parsed.scheme!r} is not permitted. " + "Only 'http' and 'https' are allowed." + ) + raise SSRFError(msg) + + hostname = parsed.hostname + if not hostname: + msg = f"URL has no hostname: {url!r}" + raise SSRFError(msg) + + port = parsed.port or (443 if parsed.scheme == "https" else 80) + _resolve_and_check(hostname, port) + + +# --------------------------------------------------------------------------- +# urllib3 connection subclasses — validation runs at socket-creation time. +# Overriding _new_conn() (called immediately before the OS connect() syscall) +# minimises the TOCTOU window to microseconds, making DNS rebinding attacks +# impractical in practice. +# --------------------------------------------------------------------------- + +class _SSRFSafeHTTPConnection(urllib3.connection.HTTPConnection): + def _new_conn(self) -> socket.socket: + _resolve_and_check(self._dns_host, self.port) + return super()._new_conn() + + +class _SSRFSafeHTTPSConnection(urllib3.connection.HTTPSConnection): + def _new_conn(self) -> socket.socket: + _resolve_and_check(self._dns_host, self.port) + return super()._new_conn() + + +class _SSRFSafeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool): + ConnectionCls = _SSRFSafeHTTPConnection + + +class _SSRFSafeHTTPSConnectionPool(urllib3.connectionpool.HTTPSConnectionPool): + ConnectionCls = _SSRFSafeHTTPSConnection + + +_SAFE_POOL_CLASSES = { + "http": _SSRFSafeHTTPConnectionPool, + "https": _SSRFSafeHTTPSConnectionPool, +} + + +class _SSRFSafeAdapter(HTTPAdapter): + + """ + A requests HTTPAdapter that injects SSRF-safe connection classes into the + urllib3 pool manager so that IP validation happens at socket-creation time + on every request, including after redirects. + """ + + def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs): + super().init_poolmanager(connections, maxsize, block, **pool_kwargs) + # Replace the pool classes after the manager is created. + # pool_classes_by_scheme is a plain dict on the instance, so this + # only affects this adapter's pool manager. + self.poolmanager.pool_classes_by_scheme = _SAFE_POOL_CLASSES + + +def make_ssrf_safe_session() -> requests.Session: + """ + Return a requests.Session with SSRF protection applied at the socket level. + + Every outbound request made through this session will have its resolved IP + validated against the private/reserved range blocklist immediately before + the OS socket is opened, preventing both: + - Direct requests to internal IP ranges + - DNS rebinding attacks + """ + session = requests.Session() + adapter = _SSRFSafeAdapter(pool_maxsize=DEFAULT_POOLSIZE) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session diff --git a/unittests/test_utils_ssrf.py b/unittests/test_utils_ssrf.py new file mode 100644 index 00000000000..904abf8101c --- /dev/null +++ b/unittests/test_utils_ssrf.py @@ -0,0 +1,75 @@ +import socket +from unittest.mock import patch + +import requests + +from dojo.utils_ssrf import SSRFError, _SSRFSafeAdapter, make_ssrf_safe_session, validate_url_for_ssrf # noqa: PLC2701 +from unittests.dojo_test_case import DojoTestCase + + +def _addr_info(ip, port=80): + """Build a minimal getaddrinfo-style return value for a single IP.""" + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port))] + + +_MIXED_ADDR_INFO = [ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 80)), + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.168.1.1", 80)), +] + + +class TestValidateUrlForSsrf(DojoTestCase): + + @patch("dojo.utils_ssrf.socket.getaddrinfo", return_value=_addr_info("8.8.8.8")) + def test_valid_public_url_does_not_raise(self, mock_getaddrinfo): + validate_url_for_ssrf("http://example.com/api") # should not raise + + def test_file_scheme_raises(self): + with self.assertRaisesRegex(SSRFError, "not permitted"): + validate_url_for_ssrf("file:///etc/passwd") + + def test_gopher_scheme_raises(self): + with self.assertRaisesRegex(SSRFError, "not permitted"): + validate_url_for_ssrf("gopher://example.com") + + def test_no_hostname_raises(self): + with self.assertRaisesRegex(SSRFError, "no hostname"): + validate_url_for_ssrf("http://") + + def test_loopback_ip_raises(self): + with self.assertRaisesRegex(SSRFError, "non-public address"): + validate_url_for_ssrf("http://127.0.0.1/") + + def test_private_class_c_raises(self): + with self.assertRaisesRegex(SSRFError, "non-public address"): + validate_url_for_ssrf("http://192.168.1.1/") + + def test_private_class_a_raises(self): + with self.assertRaisesRegex(SSRFError, "non-public address"): + validate_url_for_ssrf("http://10.0.0.1/") + + def test_link_local_raises(self): + with self.assertRaisesRegex(SSRFError, "non-public address"): + validate_url_for_ssrf("http://169.254.1.1/") + + @patch("dojo.utils_ssrf.socket.getaddrinfo", side_effect=socket.gaierror("Name or service not known")) + def test_unresolvable_hostname_raises(self, mock_getaddrinfo): + with self.assertRaisesRegex(SSRFError, "Unable to resolve"): + validate_url_for_ssrf("http://nonexistent.invalid/") + + @patch("dojo.utils_ssrf.socket.getaddrinfo", return_value=_MIXED_ADDR_INFO) + def test_multi_address_with_private_ip_raises(self, mock_getaddrinfo): + with self.assertRaisesRegex(SSRFError, "non-public address"): + validate_url_for_ssrf("http://example.com/") + + +class TestMakeSsrfSafeSession(DojoTestCase): + + def test_returns_requests_session(self): + session = make_ssrf_safe_session() + self.assertIsInstance(session, requests.Session) + + def test_http_and_https_mounted_with_safe_adapter(self): + session = make_ssrf_safe_session() + self.assertIsInstance(session.get_adapter("http://example.com"), _SSRFSafeAdapter) + self.assertIsInstance(session.get_adapter("https://example.com"), _SSRFSafeAdapter) diff --git a/unittests/tools/test_risk_recon_parser.py b/unittests/tools/test_risk_recon_parser.py index c59b39bd7d9..a2fec88180d 100644 --- a/unittests/tools/test_risk_recon_parser.py +++ b/unittests/tools/test_risk_recon_parser.py @@ -1,9 +1,10 @@ import datetime - -import requests +from unittest.mock import MagicMock, patch from dojo.models import Test +from dojo.tools.risk_recon.api import RiskReconAPI from dojo.tools.risk_recon.parser import RiskReconParser +from dojo.utils_ssrf import SSRFError from unittests.dojo_test_case import DojoTestCase, get_unit_tests_scans_path @@ -11,7 +12,7 @@ class TestRiskReconAPIParser(DojoTestCase): def test_api_with_bad_url(self): with (get_unit_tests_scans_path("risk_recon") / "bad_url.json").open(encoding="utf-8") as testfile, \ - self.assertRaises(requests.exceptions.ConnectionError): + self.assertRaises(Exception): # noqa: B017 # SSRFError is caught and re-raised as Exception in api.py parser = RiskReconParser() parser.get_findings(testfile, Test()) @@ -34,3 +35,20 @@ def test_parser_without_api(self): finding = findings[1] self.assertEqual(datetime.date(2017, 3, 17), finding.date.date()) self.assertEqual("ff2bbdbfc2b6gsrgwergwe6b1fasfwefb", finding.unique_id_from_tool) + + @patch("dojo.tools.risk_recon.api.validate_url_for_ssrf", side_effect=SSRFError("blocked: private address")) + def test_ssrf_error_is_raised_as_exception(self, mock_validate): + with self.assertRaisesRegex(Exception, "Invalid Risk Recon API url"): + RiskReconAPI(api_key="somekey", endpoint="http://192.168.1.1/api", data=[]) + mock_validate.assert_called_once_with("http://192.168.1.1/api") + + @patch.object(RiskReconAPI, "get_findings") + @patch.object(RiskReconAPI, "map_toes") + @patch("dojo.tools.risk_recon.api.make_ssrf_safe_session") + @patch("dojo.tools.risk_recon.api.validate_url_for_ssrf") + def test_make_ssrf_safe_session_called_on_init(self, mock_validate, mock_make_session, mock_map_toes, mock_get_findings): + mock_session = MagicMock() + mock_make_session.return_value = mock_session + api = RiskReconAPI(api_key="somekey", endpoint="https://api.riskrecon.com/v1", data=[]) + mock_make_session.assert_called_once() + self.assertIs(api.session, mock_session)