Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions localstack-typedb/localstack_typedb/utils/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from localstack import config
from localstack.config import is_env_true
from localstack_typedb.utils.h2_proxy import apply_http2_patches_for_grpc_support
from localstack_typedb.utils.h2_proxy import (
apply_http2_patches_for_grpc_support,
ProxyRequestMatcher,
)
from localstack.utils.docker_utils import DOCKER_CLIENT
from localstack.extensions.api import Extension, http
from localstack.http import Request
Expand All @@ -16,6 +19,7 @@
from rolo import route
from rolo.proxy import Proxy
from rolo.routing import RuleAdapter, WithHost
from werkzeug.datastructures import Headers

LOG = logging.getLogger(__name__)
logging.getLogger("localstack_typedb").setLevel(
Expand All @@ -24,7 +28,7 @@
logging.basicConfig()


class ProxiedDockerContainerExtension(Extension):
class ProxiedDockerContainerExtension(Extension, ProxyRequestMatcher):
name: str
"""Name of this extension"""
image_name: str
Expand Down Expand Up @@ -82,7 +86,9 @@ def update_gateway_routes(self, router: http.Router[http.RouteHandler]):

# apply patches to serve HTTP/2 requests
for port in self.http2_ports or []:
apply_http2_patches_for_grpc_support(get_addressable_container_host(), port)
apply_http2_patches_for_grpc_support(
get_addressable_container_host(), port, self
)

def on_platform_shutdown(self):
self._remove_container()
Expand All @@ -94,6 +100,15 @@ def _get_container_name(self) -> str:
name = re.sub(r"\W", "-", name)
return name

def should_proxy_request(self, headers: Headers) -> bool:
# determine if this is a gRPC request targeting TypeDB
content_type = headers.get("content-type") or ""
req_path = headers.get(":path") or ""
is_typedb_grpc_request = (
"grpc" in content_type and "/typedb.protocol.TypeDB" in req_path
)
return is_typedb_grpc_request

@cache
def start_container(self) -> None:
container_name = self._get_container_name()
Expand Down
48 changes: 29 additions & 19 deletions localstack-typedb/localstack_typedb/utils/h2_proxy.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,35 @@
import logging
import socket
from abc import abstractmethod

from h2.frame_buffer import FrameBuffer
from hpack import Decoder
from hyperframe.frame import HeadersFrame
from hyperframe.frame import HeadersFrame, Frame
from twisted.internet import reactor

from localstack.utils.patch import patch
from twisted.web._http2 import H2Connection

from werkzeug.datastructures import Headers

LOG = logging.getLogger(__name__)


class ProxyRequestMatcher:
"""
Abstract base class that defines a request matcher, for an extension to define which incoming
request messages should be proxied to an upstream target (and which ones shouldn't).
"""

@abstractmethod
def should_proxy_request(self, headers: Headers) -> bool:
"""Define whether a request should be proxied, based on request headers."""


class TcpForwarder:
"""Simple helper class for bidirectional forwarding of TPC traffic."""

buffer_size = 1024
buffer_size: int = 1024
"""Data buffer size for receiving data from upstream socket."""

def __init__(self, port: int, host: str = "localhost"):
self.port = port
Expand Down Expand Up @@ -49,7 +62,9 @@ def close(self):
pass


def apply_http2_patches_for_grpc_support(target_host: str, target_port: int):
def apply_http2_patches_for_grpc_support(
target_host: str, target_port: int, request_matcher: ProxyRequestMatcher
):
"""
Apply some patches to proxy incoming gRPC requests and forward them to a target port.
Note: this is a very brute-force approach and needs to be fixed/enhanced over time!
Expand All @@ -71,11 +86,11 @@ def _process(data):
@patch(H2Connection.dataReceived)
def _dataReceived(fn, self, data, *args, **kwargs):
forwarder = getattr(self, "_ls_forwarder", None)
is_typedb_grpc_request = getattr(self, "_is_typedb_grpc_request", None)
if not forwarder or is_typedb_grpc_request is False:
should_proxy_request = getattr(self, "_ls_should_proxy_request", None)
if not forwarder or should_proxy_request is False:
return fn(self, data, *args, **kwargs)

if is_typedb_grpc_request:
if should_proxy_request:
forwarder.send(data)
return

Expand All @@ -88,14 +103,10 @@ def _dataReceived(fn, self, data, *args, **kwargs):
# if no headers received yet, then return (method will be called again for next chunk of data)
return

# determine if this is a gRPC request targeting TypeDB - TODO make configurable!
content_type = headers.get("content-type") or ""
req_path = headers.get(":path") or ""
self._is_typedb_grpc_request = (
"grpc" in content_type and "/typedb.protocol.TypeDB" in req_path
)
# check if the incoming request should be proxies, based on the request headers
self._ls_should_proxy_request = request_matcher.should_proxy_request(headers)

if not self._is_typedb_grpc_request:
if not self._ls_should_proxy_request:
# if this is not a target request, then call the upstream function
result = None
for chunk in self._data_received:
Expand All @@ -119,15 +130,15 @@ def connectionLost(fn, self, *args, **kwargs):
forwarder.close()


def get_headers_from_data_stream(data_list: list) -> dict:
def get_headers_from_data_stream(data_list: list[bytes]) -> Headers:
"""Get headers from a data stream (list of bytes data), if any headers are contained."""
data_combined = b"".join(data_list)
frames = parse_http2_stream(data_combined)
headers = get_headers_from_frames(frames)
return headers


def get_headers_from_frames(frames: list) -> dict:
def get_headers_from_frames(frames: list[Frame]) -> Headers:
"""Parse the given list of HTTP2 frames and return a dict of headers, if any"""
result = {}
decoder = Decoder()
Expand All @@ -138,10 +149,10 @@ def get_headers_from_frames(frames: list) -> dict:
result.update(dict(headers))
except Exception:
pass
return result
return Headers(result)


def parse_http2_stream(data: bytes) -> list:
def parse_http2_stream(data: bytes) -> list[Frame]:
"""Parse the data from an HTTP2 stream into a list of frames"""
frames = []
buffer = FrameBuffer(server=True)
Expand All @@ -152,5 +163,4 @@ def parse_http2_stream(data: bytes) -> list:
frames.append(frame)
except Exception:
pass

return frames