Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion localstack-typedb/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ format: ## Run ruff to format the whole codebase
$(VENV_RUN); python -m ruff format .; python -m ruff check --output-format=full --fix .

test: ## Run integration tests (requires LocalStack running with the Extension installed)
$(VENV_RUN); pytest tests
$(VENV_RUN); pytest tests $(PYTEST_ARGS)

clean-dist: clean
rm -rf dist/
Expand Down
6 changes: 5 additions & 1 deletion localstack-typedb/localstack_typedb/extension.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import shlex

from localstack.config import is_env_not_false
from localstack.utils.docker_utils import DOCKER_CLIENT
from localstack_typedb.utils.docker import ProxiedDockerContainerExtension
from rolo import Request

# environment variable for user-defined command args to pass to TypeDB
ENV_CMD_FLAGS = "TYPEDB_FLAGS"
# environment variable for flag to enable/disable HTTP2 proxy for gRPC traffic
ENV_HTTP2_PROXY = "TYPEDB_HTTP2_PROXY"


class TypeDbExtension(ProxiedDockerContainerExtension):
Expand All @@ -24,13 +27,14 @@ def __init__(self):
command_flags = (os.environ.get(ENV_CMD_FLAGS) or "").strip()
command_flags = self.DEFAULT_CMD_FLAGS + shlex.split(command_flags)
command = self._get_image_command() + command_flags
http2_ports = [self.TYPEDB_PORT] if is_env_not_false(ENV_HTTP2_PROXY) else []
super().__init__(
image_name=self.DOCKER_IMAGE,
container_ports=[8000, 1729],
host=self.HOST,
request_to_port_router=self.request_to_port_router,
command=command,
http2_ports=[self.TYPEDB_PORT],
http2_ports=http2_ports,
)

def _get_image_command(self) -> list[str]:
Expand Down
98 changes: 95 additions & 3 deletions localstack-typedb/localstack_typedb/utils/h2_proxy.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import logging
import socket

from h2.frame_buffer import FrameBuffer
from hpack import Decoder
from hyperframe.frame import HeadersFrame
from twisted.internet import reactor

from localstack.utils.patch import patch
from twisted.web._http2 import H2Connection


LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -35,6 +39,15 @@ def receive_loop(self, callback):
def send(self, data):
self._socket.sendall(data)

def close(self):
LOG.debug("Closing connection to upstream HTTP2 server on port %s", self.port)
try:
self._socket.shutdown(socket.SHUT_RDWR)
self._socket.close()
except Exception:
# swallow exceptions here (e.g., "bad file descriptor")
pass


def apply_http2_patches_for_grpc_support(target_host: str, target_port: int):
"""
Expand All @@ -58,7 +71,86 @@ def _process(data):
@patch(H2Connection.dataReceived)
def _dataReceived(fn, self, data, *args, **kwargs):
forwarder = getattr(self, "_ls_forwarder", None)
if not forwarder:
is_typedb_grpc_request = getattr(self, "_is_typedb_grpc_request", None)
if not forwarder or is_typedb_grpc_request is False:
return fn(self, data, *args, **kwargs)
LOG.debug("Forwarding data (%s bytes) from HTTP2 client to server", len(data))
forwarder.send(data)

if is_typedb_grpc_request:
forwarder.send(data)
return

setattr(self, "_data_received", getattr(self, "_data_received", []))
self._data_received.append(data)

# parse headers from request frames received so far
headers = get_headers_from_data_stream(self._data_received)
if not headers:
# if no headers received yet, then return (method will be called again for next chunk of data)
return

# determine if this is a gRPC request targeting TypeDB - TODO make configurable!
content_type = headers.get("content-type") or ""
req_path = headers.get(":path") or ""
self._is_typedb_grpc_request = (
"grpc" in content_type and "/typedb.protocol.TypeDB" in req_path
)

if not self._is_typedb_grpc_request:
# if this is not a target request, then call the upstream function
result = None
for chunk in self._data_received:
result = fn(self, chunk, *args, **kwargs)
self._data_received = []
return result

# forward data chunks to the target
for chunk in self._data_received:
LOG.debug(
"Forwarding data (%s bytes) from HTTP2 client to server", len(chunk)
)
forwarder.send(chunk)
self._data_received = []

@patch(H2Connection.connectionLost)
def connectionLost(fn, self, *args, **kwargs):
forwarder = getattr(self, "_ls_forwarder", None)
if not forwarder:
return fn(self, *args, **kwargs)
forwarder.close()


def get_headers_from_data_stream(data_list: list) -> dict:
"""Get headers from a data stream (list of bytes data), if any headers are contained."""
data_combined = b"".join(data_list)
frames = parse_http2_stream(data_combined)
headers = get_headers_from_frames(frames)
return headers


def get_headers_from_frames(frames: list) -> dict:
"""Parse the given list of HTTP2 frames and return a dict of headers, if any"""
result = {}
decoder = Decoder()
for frame in frames:
if isinstance(frame, HeadersFrame):
try:
headers = decoder.decode(frame.data)
result.update(dict(headers))
except Exception:
pass
return result


def parse_http2_stream(data: bytes) -> list:
"""Parse the data from an HTTP2 stream into a list of frames"""
frames = []
buffer = FrameBuffer(server=True)
buffer.max_frame_size = 16384
buffer.add_data(data)
try:
for frame in buffer:
frames.append(frame)
except Exception:
pass

return frames
6 changes: 5 additions & 1 deletion localstack-typedb/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ authors = [
keywords = ["LocalStack", "TypeDB"]
classifiers = []
dependencies = [
"httpx"
"httpx",
"h2",
"priority",
]

[project.urls]
Expand All @@ -23,11 +25,13 @@ Homepage = "https://github.com/whummer/localstack-utils"
[project.optional-dependencies]
dev = [
"boto3",
"build",
"jsonpatch",
"localstack",
"pytest",
"rolo",
"ruff",
"twisted",
"typedb-driver",
]

Expand Down
15 changes: 15 additions & 0 deletions localstack-typedb/tests/test_extension.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import requests
from localstack.utils.strings import short_uid
from localstack_typedb.utils.h2_proxy import parse_http2_stream, get_headers_from_frames
from typedb.driver import TypeDB, Credentials, DriverOptions, TransactionType


Expand Down Expand Up @@ -67,3 +68,17 @@ def test_connect_to_db_via_grpc_endpoint():
).resolve()
for json in results:
print(json)


def test_parse_http2_frames():
# note: the data below is a dump taken from a browser request made against the emulator
data = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00\x18\x04\x00\x00\x00\x00\x00\x00\x01\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x04\x00\x02\x00\x00\x00\x05\x00\x00@\x00\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\xbf\x00\x01"
data += b"\x00\x01V\x01%\x00\x00\x00\x03\x00\x00\x00\x00\x15C\x87\xd5\xaf~MZw\x7f\x05\x8eb*\x0eA\xd0\x84\x8c\x9dX\x9c\xa3\xa13\xffA\x96\xa0\xe4\x1d\x13\x9d\t^\x83\x90t!#'U\xc9A\xed\x92\xe3M\xb8\xe7\x87z\xbe\xd0\x7ff\xa2\x81\xb0\xda\xe0S\xfa\xd02\x1a\xa4\x9d\x13\xfd\xa9\x92\xa4\x96\x854\x0c\x8aj\xdc\xa7\xe2\x81\x02\xe1o\xedK;\xdc\x0bM.\x0f\xedLE'S\xb0 \x04\x00\x08\x02\xa6\x13XYO\xe5\x80\xb4\xd2\xe0S\x83\xf9c\xe7Q\x8b-Kp\xdd\xf4Z\xbe\xfb@\x05\xdbP\x92\x9b\xd9\xab\xfaRB\xcb@\xd2_\xa5#\xb3\xe9OhL\x9f@\x94\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb5%L\xe7\x93\x83\xc5\x83\x7f@\x95\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb4\xe5\x1c\x85\xb1\x1f\x89\x1d\xa9\x9c\xf6\x1b\xd8\xd2c\xd5s\x95\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb#\x1f@\x85=\x86\x98\xd5\x7f\x94\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb'@\x8aAH\xb4\xa5I'ZB\xa1?\x84-5\xa7\xd7@\x8aAH\xb4\xa5I'Z\x93\xc8_\x83!\xecG@\x8aAH\xb4\xa5I'Y\x06I\x7f\x86@\xe9*\xc82K@\x86\xae\xc3\x1e\xc3'\xd7\x83\xb6\x06\xbf@\x82I\x7f\x86M\x835\x05\xb1\x1f\x00\x00\x04\x08\x00\x00\x00\x00\x03\x00\xbe\x00\x00"

frames = parse_http2_stream(data)
assert frames
headers = get_headers_from_frames(frames)
assert headers
assert headers[":scheme"] == "https"
assert headers[":method"] == "OPTIONS"
assert headers[":path"] == "/_localstack/health"