Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions dojo/tools/risk_recon/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import requests
from django.conf import settings

from dojo.utils_ssrf import SSRFError, make_ssrf_safe_session, validate_url_for_ssrf


class RiskReconAPI:
def __init__(self, api_key, endpoint, data):
Expand All @@ -26,7 +27,14 @@ def __init__(self, api_key, endpoint, data):
raise Exception(msg)
if self.url.endswith("/"):
self.url = endpoint[:-1]
self.session = requests.Session()

try:
validate_url_for_ssrf(self.url)
except SSRFError as exc:
msg = f"Invalid Risk Recon API url: {exc}"
raise Exception(msg) from exc

self.session = make_ssrf_safe_session()
self.map_toes()
self.get_findings()

Expand Down Expand Up @@ -54,7 +62,7 @@ def map_toes(self):
filters = comps.get(name)
self.toe_map[toe_id] = filters or self.data
else:
msg = f"Unable to query Target of Evaluations due to {response.status_code} - {response.content}"
msg = f"Unable to query Target of Evaluations due to {response.status_code}"
raise Exception(msg) # TODO: when implementing ruff BLE001, please fix also TODO in unittests/test_risk_recon.py

def filter_finding(self, finding):
Expand Down Expand Up @@ -86,5 +94,5 @@ def get_findings(self):
if not self.filter_finding(finding):
self.findings.append(finding)
else:
msg = f"Unable to collect findings from toe: {toe} due to {response.status_code} - {response.content}"
msg = f"Unable to collect findings from toe: {toe} due to {response.status_code}"
raise Exception(msg)
172 changes: 172 additions & 0 deletions dojo/utils_ssrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
SSRF (Server-Side Request Forgery) protection utilities.

Provides a requests.Session that validates outbound URLs against private/reserved
IP ranges at socket-creation time, closing the DNS rebinding (TOCTOU) window that
exists when validation is performed only as a pre-flight step.

Usage:
from dojo.utils_ssrf import make_ssrf_safe_session, validate_url_for_ssrf, SSRFError

# Pre-flight validation (raises SSRFError with a human-readable message):
validate_url_for_ssrf(url)

# Safe session (validates at socket-creation time on every request):
session = make_ssrf_safe_session()
response = session.get(url)
"""

import ipaddress
import socket
from urllib.parse import urlparse

import requests
import urllib3.connection
import urllib3.connectionpool
from requests.adapters import DEFAULT_POOLBLOCK, DEFAULT_POOLSIZE, HTTPAdapter


class SSRFError(ValueError):

"""Raised when a URL is determined to be unsafe for server-side requests."""


_ALLOWED_SCHEMES = frozenset({"http", "https"})


def _check_ip(ip_str: str) -> None:
"""Raise SSRFError if the IP address is not globally routable."""
try:
ip = ipaddress.ip_address(ip_str)
except ValueError as exc:
msg = f"Cannot parse IP address: {ip_str!r}"
raise SSRFError(msg) from exc

# ip.is_global is False for loopback, link-local (169.254.x.x), RFC 1918,
# reserved, multicast, and unspecified addresses.
if not ip.is_global:
msg = (
f"Blocked: URL resolved to non-public address {ip}. "
"Requests to private, loopback, link-local, or reserved "
"addresses are not permitted."
)
raise SSRFError(msg)


def _resolve_and_check(hostname: str, port: int) -> None:
"""Resolve hostname and verify every returned address is publicly routable."""
try:
addr_infos = socket.getaddrinfo(
hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM,
)
except socket.gaierror as exc:
msg = f"Unable to resolve hostname {hostname!r}: {exc}"
raise SSRFError(msg) from exc

if not addr_infos:
msg = f"No addresses returned for hostname {hostname!r}"
raise SSRFError(msg)

for _family, _type, _proto, _canon, sockaddr in addr_infos:
_check_ip(sockaddr[0])


def validate_url_for_ssrf(url: str) -> None:
"""
Pre-flight SSRF validation for a URL.

Checks:
- Scheme is http or https (blocks file://, gopher://, etc.)
- Every resolved IP address is globally routable (blocks RFC 1918,
loopback 127.x, link-local 169.254.x.x, and other reserved ranges)

Raises SSRFError with a descriptive message if the URL is unsafe.
This is a best-effort pre-flight check; use make_ssrf_safe_session() for
socket-level enforcement that also mitigates DNS rebinding.
"""
try:
parsed = urlparse(url)
except Exception as exc:
msg = f"Malformed URL: {url!r}"
raise SSRFError(msg) from exc

if parsed.scheme not in _ALLOWED_SCHEMES:
msg = (
f"URL scheme {parsed.scheme!r} is not permitted. "
"Only 'http' and 'https' are allowed."
)
raise SSRFError(msg)

hostname = parsed.hostname
if not hostname:
msg = f"URL has no hostname: {url!r}"
raise SSRFError(msg)

port = parsed.port or (443 if parsed.scheme == "https" else 80)
_resolve_and_check(hostname, port)


# ---------------------------------------------------------------------------
# urllib3 connection subclasses — validation runs at socket-creation time.
# Overriding _new_conn() (called immediately before the OS connect() syscall)
# minimises the TOCTOU window to microseconds, making DNS rebinding attacks
# impractical in practice.
# ---------------------------------------------------------------------------

class _SSRFSafeHTTPConnection(urllib3.connection.HTTPConnection):
def _new_conn(self) -> socket.socket:
_resolve_and_check(self._dns_host, self.port)
return super()._new_conn()


class _SSRFSafeHTTPSConnection(urllib3.connection.HTTPSConnection):
def _new_conn(self) -> socket.socket:
_resolve_and_check(self._dns_host, self.port)
return super()._new_conn()


class _SSRFSafeHTTPConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
ConnectionCls = _SSRFSafeHTTPConnection


class _SSRFSafeHTTPSConnectionPool(urllib3.connectionpool.HTTPSConnectionPool):
ConnectionCls = _SSRFSafeHTTPSConnection


_SAFE_POOL_CLASSES = {
"http": _SSRFSafeHTTPConnectionPool,
"https": _SSRFSafeHTTPSConnectionPool,
}


class _SSRFSafeAdapter(HTTPAdapter):

"""
A requests HTTPAdapter that injects SSRF-safe connection classes into the
urllib3 pool manager so that IP validation happens at socket-creation time
on every request, including after redirects.
"""

def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs):
super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
# Replace the pool classes after the manager is created.
# pool_classes_by_scheme is a plain dict on the instance, so this
# only affects this adapter's pool manager.
self.poolmanager.pool_classes_by_scheme = _SAFE_POOL_CLASSES


def make_ssrf_safe_session() -> requests.Session:
"""
Return a requests.Session with SSRF protection applied at the socket level.

Every outbound request made through this session will have its resolved IP
validated against the private/reserved range blocklist immediately before
the OS socket is opened, preventing both:
- Direct requests to internal IP ranges
- DNS rebinding attacks
"""
session = requests.Session()
adapter = _SSRFSafeAdapter(pool_maxsize=DEFAULT_POOLSIZE)
session.mount("http://", adapter)
session.mount("https://", adapter)
return session
75 changes: 75 additions & 0 deletions unittests/test_utils_ssrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import socket
from unittest.mock import patch

import requests

from dojo.utils_ssrf import SSRFError, _SSRFSafeAdapter, make_ssrf_safe_session, validate_url_for_ssrf # noqa: PLC2701
from unittests.dojo_test_case import DojoTestCase


def _addr_info(ip, port=80):
"""Build a minimal getaddrinfo-style return value for a single IP."""
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, port))]


_MIXED_ADDR_INFO = [
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 80)),
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.168.1.1", 80)),
]


class TestValidateUrlForSsrf(DojoTestCase):

@patch("dojo.utils_ssrf.socket.getaddrinfo", return_value=_addr_info("8.8.8.8"))
def test_valid_public_url_does_not_raise(self, mock_getaddrinfo):
validate_url_for_ssrf("http://example.com/api") # should not raise

def test_file_scheme_raises(self):
with self.assertRaisesRegex(SSRFError, "not permitted"):
validate_url_for_ssrf("file:///etc/passwd")

def test_gopher_scheme_raises(self):
with self.assertRaisesRegex(SSRFError, "not permitted"):
validate_url_for_ssrf("gopher://example.com")

def test_no_hostname_raises(self):
with self.assertRaisesRegex(SSRFError, "no hostname"):
validate_url_for_ssrf("http://")

def test_loopback_ip_raises(self):
with self.assertRaisesRegex(SSRFError, "non-public address"):
validate_url_for_ssrf("http://127.0.0.1/")

def test_private_class_c_raises(self):
with self.assertRaisesRegex(SSRFError, "non-public address"):
validate_url_for_ssrf("http://192.168.1.1/")

def test_private_class_a_raises(self):
with self.assertRaisesRegex(SSRFError, "non-public address"):
validate_url_for_ssrf("http://10.0.0.1/")

def test_link_local_raises(self):
with self.assertRaisesRegex(SSRFError, "non-public address"):
validate_url_for_ssrf("http://169.254.1.1/")

@patch("dojo.utils_ssrf.socket.getaddrinfo", side_effect=socket.gaierror("Name or service not known"))
def test_unresolvable_hostname_raises(self, mock_getaddrinfo):
with self.assertRaisesRegex(SSRFError, "Unable to resolve"):
validate_url_for_ssrf("http://nonexistent.invalid/")

@patch("dojo.utils_ssrf.socket.getaddrinfo", return_value=_MIXED_ADDR_INFO)
def test_multi_address_with_private_ip_raises(self, mock_getaddrinfo):
with self.assertRaisesRegex(SSRFError, "non-public address"):
validate_url_for_ssrf("http://example.com/")


class TestMakeSsrfSafeSession(DojoTestCase):

def test_returns_requests_session(self):
session = make_ssrf_safe_session()
self.assertIsInstance(session, requests.Session)

def test_http_and_https_mounted_with_safe_adapter(self):
session = make_ssrf_safe_session()
self.assertIsInstance(session.get_adapter("http://example.com"), _SSRFSafeAdapter)
self.assertIsInstance(session.get_adapter("https://example.com"), _SSRFSafeAdapter)
24 changes: 21 additions & 3 deletions unittests/tools/test_risk_recon_parser.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import datetime

import requests
from unittest.mock import MagicMock, patch

from dojo.models import Test
from dojo.tools.risk_recon.api import RiskReconAPI
from dojo.tools.risk_recon.parser import RiskReconParser
from dojo.utils_ssrf import SSRFError
from unittests.dojo_test_case import DojoTestCase, get_unit_tests_scans_path


class TestRiskReconAPIParser(DojoTestCase):

def test_api_with_bad_url(self):
with (get_unit_tests_scans_path("risk_recon") / "bad_url.json").open(encoding="utf-8") as testfile, \
self.assertRaises(requests.exceptions.ConnectionError):
self.assertRaises(Exception): # noqa: B017 # SSRFError is caught and re-raised as Exception in api.py
parser = RiskReconParser()
parser.get_findings(testfile, Test())

Expand All @@ -34,3 +35,20 @@ def test_parser_without_api(self):
finding = findings[1]
self.assertEqual(datetime.date(2017, 3, 17), finding.date.date())
self.assertEqual("ff2bbdbfc2b6gsrgwergwe6b1fasfwefb", finding.unique_id_from_tool)

@patch("dojo.tools.risk_recon.api.validate_url_for_ssrf", side_effect=SSRFError("blocked: private address"))
def test_ssrf_error_is_raised_as_exception(self, mock_validate):
with self.assertRaisesRegex(Exception, "Invalid Risk Recon API url"):
RiskReconAPI(api_key="somekey", endpoint="http://192.168.1.1/api", data=[])
mock_validate.assert_called_once_with("http://192.168.1.1/api")

@patch.object(RiskReconAPI, "get_findings")
@patch.object(RiskReconAPI, "map_toes")
@patch("dojo.tools.risk_recon.api.make_ssrf_safe_session")
@patch("dojo.tools.risk_recon.api.validate_url_for_ssrf")
def test_make_ssrf_safe_session_called_on_init(self, mock_validate, mock_make_session, mock_map_toes, mock_get_findings):
mock_session = MagicMock()
mock_make_session.return_value = mock_session
api = RiskReconAPI(api_key="somekey", endpoint="https://api.riskrecon.com/v1", data=[])
mock_make_session.assert_called_once()
self.assertIs(api.session, mock_session)
Loading