From fc624babb570f07d240e9e1de64ae8faca6bd97d Mon Sep 17 00:00:00 2001
From: Waldemar Hummer
Date: Thu, 6 Nov 2025 21:38:33 +0200
Subject: [PATCH] add ProxyRequestMatcher to make proxy request filter
 generically reusable

---
 .../localstack_typedb/utils/docker.py   | 21 ++++++--
 .../localstack_typedb/utils/h2_proxy.py | 48 +++++++++++--------
 2 files changed, 47 insertions(+), 22 deletions(-)

diff --git a/localstack-typedb/localstack_typedb/utils/docker.py b/localstack-typedb/localstack_typedb/utils/docker.py
index ceef35e..807be19 100644
--- a/localstack-typedb/localstack_typedb/utils/docker.py
+++ b/localstack-typedb/localstack_typedb/utils/docker.py
@@ -6,7 +6,10 @@
 
 from localstack import config
 from localstack.config import is_env_true
-from localstack_typedb.utils.h2_proxy import apply_http2_patches_for_grpc_support
+from localstack_typedb.utils.h2_proxy import (
+    apply_http2_patches_for_grpc_support,
+    ProxyRequestMatcher,
+)
 from localstack.utils.docker_utils import DOCKER_CLIENT
 from localstack.extensions.api import Extension, http
 from localstack.http import Request
@@ -16,6 +19,7 @@
 from rolo import route
 from rolo.proxy import Proxy
 from rolo.routing import RuleAdapter, WithHost
+from werkzeug.datastructures import Headers
 
 LOG = logging.getLogger(__name__)
 logging.getLogger("localstack_typedb").setLevel(
@@ -24,7 +28,7 @@
 logging.basicConfig()
 
 
-class ProxiedDockerContainerExtension(Extension):
+class ProxiedDockerContainerExtension(Extension, ProxyRequestMatcher):
     name: str
     """Name of this extension"""
     image_name: str
@@ -82,7 +86,9 @@ def update_gateway_routes(self, router: http.Router[http.RouteHandler]):
 
         # apply patches to serve HTTP/2 requests
        for port in self.http2_ports or []:
-            apply_http2_patches_for_grpc_support(get_addressable_container_host(), port)
+            apply_http2_patches_for_grpc_support(
+                get_addressable_container_host(), port, self
+            )
 
     def on_platform_shutdown(self):
         self._remove_container()
@@ -94,6 +100,15 @@ def _get_container_name(self) -> str:
         name = re.sub(r"\W", "-", name)
         return name
 
+    def should_proxy_request(self, headers: Headers) -> bool:
+        # determine if this is a gRPC request targeting TypeDB
+        content_type = headers.get("content-type") or ""
+        req_path = headers.get(":path") or ""
+        is_typedb_grpc_request = (
+            "grpc" in content_type and "/typedb.protocol.TypeDB" in req_path
+        )
+        return is_typedb_grpc_request
+
     @cache
     def start_container(self) -> None:
         container_name = self._get_container_name()
diff --git a/localstack-typedb/localstack_typedb/utils/h2_proxy.py b/localstack-typedb/localstack_typedb/utils/h2_proxy.py
index 23bd21a..2beccca 100644
--- a/localstack-typedb/localstack_typedb/utils/h2_proxy.py
+++ b/localstack-typedb/localstack_typedb/utils/h2_proxy.py
@@ -1,22 +1,35 @@
 import logging
 import socket
+from abc import abstractmethod
 
 from h2.frame_buffer import FrameBuffer
 from hpack import Decoder
-from hyperframe.frame import HeadersFrame
+from hyperframe.frame import HeadersFrame, Frame
 from twisted.internet import reactor
 
 from localstack.utils.patch import patch
 from twisted.web._http2 import H2Connection
-
+from werkzeug.datastructures import Headers
 
 LOG = logging.getLogger(__name__)
 
 
+class ProxyRequestMatcher:
+    """
+    Abstract base class for a request matcher, letting an extension decide which incoming
+    request messages should be proxied to an upstream target (and which should not).
+    """
+
+    @abstractmethod
+    def should_proxy_request(self, headers: Headers) -> bool:
+        """Decide whether a request should be proxied, based on the request headers."""
+
+
 class TcpForwarder:
     """Simple helper class for bidirectional forwarding of TCP traffic."""
 
-    buffer_size = 1024
+    buffer_size: int = 1024
+    """Data buffer size for receiving data from the upstream socket."""
 
     def __init__(self, port: int, host: str = "localhost"):
         self.port = port
@@ -49,7 +62,9 @@ def close(self):
             pass
 
 
-def apply_http2_patches_for_grpc_support(target_host: str, target_port: int):
+def apply_http2_patches_for_grpc_support(
+    target_host: str, target_port: int, request_matcher: ProxyRequestMatcher
+):
     """
     Apply some patches to proxy incoming gRPC requests and forward them to a target port.
     Note: this is a very brute-force approach and needs to be fixed/enhanced over time!
@@ -71,11 +86,11 @@ def _process(data):
     @patch(H2Connection.dataReceived)
     def _dataReceived(fn, self, data, *args, **kwargs):
         forwarder = getattr(self, "_ls_forwarder", None)
-        is_typedb_grpc_request = getattr(self, "_is_typedb_grpc_request", None)
-        if not forwarder or is_typedb_grpc_request is False:
+        should_proxy_request = getattr(self, "_ls_should_proxy_request", None)
+        if not forwarder or should_proxy_request is False:
             return fn(self, data, *args, **kwargs)
 
-        if is_typedb_grpc_request:
+        if should_proxy_request:
             forwarder.send(data)
             return
 
@@ -88,14 +103,10 @@ def _dataReceived(fn, self, data, *args, **kwargs):
             # if no headers received yet, then return (method will be called again for next chunk of data)
             return
 
-        # determine if this is a gRPC request targeting TypeDB - TODO make configurable!
-        content_type = headers.get("content-type") or ""
-        req_path = headers.get(":path") or ""
-        self._is_typedb_grpc_request = (
-            "grpc" in content_type and "/typedb.protocol.TypeDB" in req_path
-        )
+        # check whether the incoming request should be proxied, based on the request headers
+        self._ls_should_proxy_request = request_matcher.should_proxy_request(headers)
 
-        if not self._is_typedb_grpc_request:
+        if not self._ls_should_proxy_request:
             # if this is not a target request, then call the upstream function
             result = None
             for chunk in self._data_received:
@@ -119,7 +130,7 @@ def connectionLost(fn, self, *args, **kwargs):
         forwarder.close()
 
 
-def get_headers_from_data_stream(data_list: list) -> dict:
+def get_headers_from_data_stream(data_list: list[bytes]) -> Headers:
     """Get headers from a data stream (list of bytes data), if any headers are contained."""
     data_combined = b"".join(data_list)
     frames = parse_http2_stream(data_combined)
@@ -127,7 +138,7 @@
     return headers
 
 
-def get_headers_from_frames(frames: list) -> dict:
+def get_headers_from_frames(frames: list[Frame]) -> Headers:
     """Parse the given list of HTTP2 frames and return a dict of headers, if any"""
     result = {}
     decoder = Decoder()
@@ -138,10 +149,10 @@
                 result.update(dict(headers))
     except Exception:
         pass
-    return result
+    return Headers(result)
 
 
-def parse_http2_stream(data: bytes) -> list:
+def parse_http2_stream(data: bytes) -> list[Frame]:
     """Parse the data from an HTTP2 stream into a list of frames"""
     frames = []
     buffer = FrameBuffer(server=True)
@@ -152,5 +163,4 @@
             frames.append(frame)
     except Exception:
         pass
-    return frames
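
Illustrative usage sketch (not part of the patch itself): another extension could reuse the new ProxyRequestMatcher hook by implementing should_proxy_request and handing the matcher to apply_http2_patches_for_grpc_support. The MyGrpcMatcher class, the "/my.package.MyService" path prefix, and the host/port values below are hypothetical placeholders; only ProxyRequestMatcher, should_proxy_request, Headers, and apply_http2_patches_for_grpc_support come from the change above.

    from localstack_typedb.utils.h2_proxy import (
        ProxyRequestMatcher,
        apply_http2_patches_for_grpc_support,
    )
    from werkzeug.datastructures import Headers


    class MyGrpcMatcher(ProxyRequestMatcher):
        """Hypothetical matcher that proxies only gRPC calls for one service prefix."""

        def should_proxy_request(self, headers: Headers) -> bool:
            # same header fields the TypeDB extension inspects: content-type and :path
            content_type = headers.get("content-type") or ""
            req_path = headers.get(":path") or ""
            return "grpc" in content_type and req_path.startswith("/my.package.MyService")


    # hypothetical wiring: proxy matching HTTP/2 requests to an assumed upstream host/port
    apply_http2_patches_for_grpc_support("localhost", 50051, MyGrpcMatcher())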