diff --git a/localstack-typedb/Makefile b/localstack-typedb/Makefile index 658ceff..109efd7 100644 --- a/localstack-typedb/Makefile +++ b/localstack-typedb/Makefile @@ -36,7 +36,7 @@ format: ## Run ruff to format the whole codebase $(VENV_RUN); python -m ruff format .; python -m ruff check --output-format=full --fix . test: ## Run integration tests (requires LocalStack running with the Extension installed) - $(VENV_RUN); pytest tests + $(VENV_RUN); pytest tests $(PYTEST_ARGS) clean-dist: clean rm -rf dist/ diff --git a/localstack-typedb/localstack_typedb/extension.py b/localstack-typedb/localstack_typedb/extension.py index 98aba14..d77e6fd 100644 --- a/localstack-typedb/localstack_typedb/extension.py +++ b/localstack-typedb/localstack_typedb/extension.py @@ -1,12 +1,15 @@ import os import shlex +from localstack.config import is_env_not_false from localstack.utils.docker_utils import DOCKER_CLIENT from localstack_typedb.utils.docker import ProxiedDockerContainerExtension from rolo import Request # environment variable for user-defined command args to pass to TypeDB ENV_CMD_FLAGS = "TYPEDB_FLAGS" +# environment variable for flag to enable/disable HTTP2 proxy for gRPC traffic +ENV_HTTP2_PROXY = "TYPEDB_HTTP2_PROXY" class TypeDbExtension(ProxiedDockerContainerExtension): @@ -24,13 +27,14 @@ def __init__(self): command_flags = (os.environ.get(ENV_CMD_FLAGS) or "").strip() command_flags = self.DEFAULT_CMD_FLAGS + shlex.split(command_flags) command = self._get_image_command() + command_flags + http2_ports = [self.TYPEDB_PORT] if is_env_not_false(ENV_HTTP2_PROXY) else [] super().__init__( image_name=self.DOCKER_IMAGE, container_ports=[8000, 1729], host=self.HOST, request_to_port_router=self.request_to_port_router, command=command, - http2_ports=[self.TYPEDB_PORT], + http2_ports=http2_ports, ) def _get_image_command(self) -> list[str]: diff --git a/localstack-typedb/localstack_typedb/utils/h2_proxy.py b/localstack-typedb/localstack_typedb/utils/h2_proxy.py index 
def close(self):
    """
    Close the socket connection to the upstream HTTP2 server.

    This is a best-effort operation that never raises: the socket may already
    be disconnected or closed by the time this method is called.
    """
    LOG.debug("Closing connection to upstream HTTP2 server on port %s", self.port)
    try:
        self._socket.shutdown(socket.SHUT_RDWR)
    except Exception:
        # swallow exceptions here (e.g., "bad file descriptor", or the peer
        # already disconnected) - we still want to release the socket below
        pass
    try:
        # attempted independently of shutdown(), so that a failing shutdown
        # does not leave the underlying file descriptor open
        self._socket.close()
    except Exception:
        pass
def get_headers_from_data_stream(data_list: list) -> dict:
    """
    Get headers from a data stream (list of bytes data), if any headers are contained.

    :param data_list: chunks of raw bytes received so far, in order of arrival
    :return: dict of header names to values; empty if no headers could be extracted
    """
    data_combined = b"".join(data_list)
    frames = parse_http2_stream(data_combined)
    return get_headers_from_frames(frames)


def get_headers_from_frames(frames: list) -> dict:
    """
    Parse the given list of HTTP2 frames and return a dict of headers, if any.

    Only `HeadersFrame` instances are considered. If several of them are
    contained, later header values override earlier ones with the same name.

    :param frames: list of parsed HTTP2 frames
    :return: dict of header names to values; empty if nothing could be decoded
    """
    result = {}
    # a single decoder is shared across all frames of the stream
    decoder = Decoder()
    for frame in frames:
        if not isinstance(frame, HeadersFrame):
            continue
        try:
            headers = decoder.decode(frame.data)
        except Exception as e:
            # best effort: skip header blocks that cannot be decoded, but leave a trace
            LOG.debug("Unable to decode headers from HTTP2 frame: %s", e)
            continue
        result.update(dict(headers))
    return result


def parse_http2_stream(data: bytes, max_frame_size: int = 16384) -> list:
    """
    Parse the data from an HTTP2 stream into a list of frames.

    Parsing is best effort: frames parsed before an error occurred are
    returned, and an empty list is returned if the data cannot be parsed at all.

    :param data: raw bytes as sent by the client (the frame buffer is created in
        server mode, i.e., the data is expected to start with the connection preface)
    :param max_frame_size: maximum frame size accepted by the parser
        (defaults to 16384, the initial HTTP/2 value of SETTINGS_MAX_FRAME_SIZE)
    :return: list of frames that could be parsed from the data
    """
    frames = []
    buffer = FrameBuffer(server=True)
    buffer.max_frame_size = max_frame_size
    try:
        # note: adding the data can itself raise (e.g., on an invalid preamble),
        # hence it needs to be covered by the same error handling as the iteration
        buffer.add_data(data)
        for frame in buffer:
            frames.append(frame)
    except Exception as e:
        LOG.debug("Unable to parse HTTP2 frames from data (%s bytes): %s", len(data), e)

    return frames
def test_parse_http2_frames():
    """Parse a raw HTTP2 request dump and verify the pseudo-headers extracted from it."""
    # note: the data below is a dump taken from a browser request made against the emulator
    # first part: connection preface, followed by SETTINGS and WINDOW_UPDATE frames
    connection_setup = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00\x18\x04\x00\x00\x00\x00\x00\x00\x01\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x04\x00\x02\x00\x00\x00\x05\x00\x00@\x00\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\xbf\x00\x01"
    # second part: HEADERS frame of the request, followed by another WINDOW_UPDATE frame
    request_frames = b"\x00\x01V\x01%\x00\x00\x00\x03\x00\x00\x00\x00\x15C\x87\xd5\xaf~MZw\x7f\x05\x8eb*\x0eA\xd0\x84\x8c\x9dX\x9c\xa3\xa13\xffA\x96\xa0\xe4\x1d\x13\x9d\t^\x83\x90t!#'U\xc9A\xed\x92\xe3M\xb8\xe7\x87z\xbe\xd0\x7ff\xa2\x81\xb0\xda\xe0S\xfa\xd02\x1a\xa4\x9d\x13\xfd\xa9\x92\xa4\x96\x854\x0c\x8aj\xdc\xa7\xe2\x81\x02\xe1o\xedK;\xdc\x0bM.\x0f\xedLE'S\xb0 \x04\x00\x08\x02\xa6\x13XYO\xe5\x80\xb4\xd2\xe0S\x83\xf9c\xe7Q\x8b-Kp\xdd\xf4Z\xbe\xfb@\x05\xdbP\x92\x9b\xd9\xab\xfaRB\xcb@\xd2_\xa5#\xb3\xe9OhL\x9f@\x94\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb5%L\xe7\x93\x83\xc5\x83\x7f@\x95\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb4\xe5\x1c\x85\xb1\x1f\x89\x1d\xa9\x9c\xf6\x1b\xd8\xd2c\xd5s\x95\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb#\x1f@\x85=\x86\x98\xd5\x7f\x94\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb'@\x8aAH\xb4\xa5I'ZB\xa1?\x84-5\xa7\xd7@\x8aAH\xb4\xa5I'Z\x93\xc8_\x83!\xecG@\x8aAH\xb4\xa5I'Y\x06I\x7f\x86@\xe9*\xc82K@\x86\xae\xc3\x1e\xc3'\xd7\x83\xb6\x06\xbf@\x82I\x7f\x86M\x835\x05\xb1\x1f\x00\x00\x04\x08\x00\x00\x00\x00\x03\x00\xbe\x00\x00"
    data = connection_setup + request_frames

    frames = parse_http2_stream(data)
    assert frames

    headers = get_headers_from_frames(frames)
    assert headers

    # pseudo-headers expected to be contained in the request dump
    expected_headers = {
        ":scheme": "https",
        ":method": "OPTIONS",
        ":path": "/_localstack/health",
    }
    for header_name, header_value in expected_headers.items():
        assert headers[header_name] == header_value