Skip to content

Commit b75b7f2

Browse files
committed
ADD: Add Python support for live compression
1 parent 367e27e commit b75b7f2

8 files changed

Lines changed: 179 additions & 6 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
#### Enhancements
66
- Added `slow_reader_behavior` field to `AuthenticationRequest` message
77
- Added `SlowReaderBehavior` enum
8+
- Added support for using compression in the live API:
9+
- Added `compression` parameter to the `Live` client constructor
10+
- Added `compression` property to the `Live` client
11+
- Added `compression` field to `AuthenticationRequest`
812
- Upgraded `databento-dbn` to 0.49.0:
913
- Added support for decompressing Zstd in the Python `DBNDecoder`, along with a new optional `compression` parameter
1014

databento/live/client.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import databento_dbn
1616
import pandas as pd
17+
from databento_dbn import Compression
1718
from databento_dbn import DBNRecord
1819
from databento_dbn import Schema
1920
from databento_dbn import SType
@@ -68,6 +69,9 @@ class Live:
6869
The live gateway behavior when the client falls behind real time.
6970
- "skip": skip records to immediately catch up
7071
- "warn": send a slow reader warning `SystemMsg` but continue reading every record
72+
compression : Compression or str, default "none"
73+
The compression format for live data. Set to "zstd" to
73+
receive Zstandard-compressed data from the gateway.
7175
7276
"""
7377

@@ -88,6 +92,7 @@ def __init__(
8892
heartbeat_interval_s: int | None = None,
8993
reconnect_policy: ReconnectPolicy | str = ReconnectPolicy.NONE,
9094
slow_reader_behavior: SlowReaderBehavior | str | None = None,
95+
compression: Compression = Compression.NONE,
9196
) -> None:
9297
if key is None:
9398
key = os.environ.get("DATABENTO_API_KEY")
@@ -105,6 +110,7 @@ def __init__(
105110

106111
self._dataset: Dataset | str = ""
107112
self._ts_out = ts_out
113+
self._compression = compression
108114
self._heartbeat_interval_s = heartbeat_interval_s
109115

110116
self._metadata: SessionMetadata = SessionMetadata()
@@ -119,6 +125,7 @@ def __init__(
119125
user_port=port,
120126
reconnect_policy=reconnect_policy,
121127
slow_reader_behavior=slow_reader_behavior,
128+
compression=compression,
122129
)
123130

124131
self._session._user_callbacks.append(ClientRecordCallback(self._map_symbol))
@@ -298,6 +305,18 @@ def ts_out(self) -> bool:
298305
"""
299306
return self._ts_out
300307

308+
@property
309+
def compression(self) -> Compression:
310+
"""
311+
Returns the compression format for this live client.
312+
313+
Returns
314+
-------
315+
Compression
316+
317+
"""
318+
return self._compression
319+
301320
def add_callback(
302321
self,
303322
record_callback: RecordCallback,

databento/live/gateway.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import SupportsBytes
88
from typing import TypeVar
99

10+
from databento_dbn import Compression
1011
from databento_dbn import Encoding
1112
from databento_dbn import Schema
1213
from databento_dbn import SType
@@ -118,6 +119,7 @@ class AuthenticationRequest(GatewayControl):
118119
encoding: Encoding = Encoding.DBN
119120
details: str | None = None
120121
ts_out: str = "0"
122+
compression: Compression | str = Compression.NONE
121123
heartbeat_interval_s: int | None = None
122124
slow_reader_behavior: SlowReaderBehavior | str | None = None
123125
client: str = USER_AGENT

databento/live/protocol.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Final
88

99
import databento_dbn
10+
from databento_dbn import Compression
1011
from databento_dbn import DBNRecord
1112
from databento_dbn import Metadata
1213
from databento_dbn import Schema
@@ -61,6 +62,8 @@ class DatabentoLiveProtocol(asyncio.BufferedProtocol):
6162
heartbeat_interval_s: int, optional
6263
The interval in seconds at which the gateway will send heartbeat records if no
6364
other data records are sent.
65+
compression : Compression, default Compression.NONE
66+
The compression format for the session.
6467
6568
See Also
6669
--------
@@ -75,6 +78,7 @@ def __init__(
7578
ts_out: bool = False,
7679
heartbeat_interval_s: int | None = None,
7780
slow_reader_behavior: SlowReaderBehavior | str | None = None,
81+
compression: Compression = Compression.NONE,
7882
) -> None:
7983
self.__api_key = api_key
8084
self.__transport: asyncio.Transport | None = None
@@ -84,9 +88,11 @@ def __init__(
8488
self._ts_out = ts_out
8589
self._heartbeat_interval_s = heartbeat_interval_s
8690
self._slow_reader_behavior: SlowReaderBehavior | str | None = slow_reader_behavior
91+
self._compression = compression
8792

8893
self._dbn_decoder = databento_dbn.DBNDecoder(
8994
upgrade_policy=VersionUpgradePolicy.UPGRADE_TO_V3,
95+
compression=compression,
9096
)
9197
self._gateway_decoder = GatewayDecoder()
9298

@@ -443,15 +449,17 @@ def _(self, message: ChallengeRequest) -> None:
443449
auth=response,
444450
dataset=self._dataset,
445451
ts_out=str(int(self._ts_out)),
452+
compression=str(self._compression).lower(),
446453
heartbeat_interval_s=self._heartbeat_interval_s,
447454
slow_reader_behavior=self._slow_reader_behavior,
448455
)
449456
logger.debug(
450-
"sending CRAM challenge response auth='%s' dataset=%s encoding=%s ts_out=%s heartbeat_interval_s=%s client='%s'",
457+
"sending CRAM challenge response auth='%s' dataset=%s encoding=%s ts_out=%s compression=%s heartbeat_interval_s=%s client='%s'",
451458
auth_request.auth,
452459
auth_request.dataset,
453460
auth_request.encoding,
454461
auth_request.ts_out,
462+
auth_request.compression,
455463
auth_request.heartbeat_interval_s,
456464
auth_request.client,
457465
)

databento/live/session.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import databento_dbn
1616
import pandas as pd
17+
from databento_dbn import Compression
1718
from databento_dbn import DBNRecord
1819
from databento_dbn import Schema
1920
from databento_dbn import SType
@@ -207,8 +208,16 @@ def __init__(
207208
ts_out: bool = False,
208209
heartbeat_interval_s: int | None = None,
209210
slow_reader_behavior: SlowReaderBehavior | str | None = None,
211+
compression: Compression = Compression.NONE,
210212
):
211-
super().__init__(api_key, dataset, ts_out, heartbeat_interval_s, slow_reader_behavior)
213+
super().__init__(
214+
api_key,
215+
dataset,
216+
ts_out,
217+
heartbeat_interval_s,
218+
slow_reader_behavior,
219+
compression,
220+
)
212221

213222
self._dbn_queue = dbn_queue
214223
self._loop = loop
@@ -304,6 +313,8 @@ class LiveSession:
304313
The reconnect policy for the live session.
305314
- "none": the client will not reconnect (default)
306315
- "reconnect": the client will reconnect automatically
316+
compression : Compression, optional
317+
The compression format for the session. Defaults to no compression.
307318
"""
308319

309320
def __init__(
@@ -316,6 +327,7 @@ def __init__(
316327
user_port: int = DEFAULT_REMOTE_PORT,
317328
reconnect_policy: ReconnectPolicy | str = ReconnectPolicy.NONE,
318329
slow_reader_behavior: SlowReaderBehavior | str | None = None,
330+
compression: Compression = Compression.NONE,
319331
) -> None:
320332
self._dbn_queue = DBNQueue()
321333
self._lock = threading.RLock()
@@ -333,6 +345,7 @@ def __init__(
333345
self._ts_out = ts_out
334346
self._heartbeat_interval_s = heartbeat_interval_s or 30
335347
self._slow_reader_behavior = slow_reader_behavior
348+
self._compression = compression
336349

337350
self._protocol: _SessionProtocol | None = None
338351
self._transport: asyncio.Transport | None = None
@@ -584,6 +597,7 @@ def _create_protocol(self, dataset: Dataset | str) -> _SessionProtocol:
584597
ts_out=self.ts_out,
585598
heartbeat_interval_s=self.heartbeat_interval_s,
586599
slow_reader_behavior=self._slow_reader_behavior,
600+
compression=self._compression,
587601
)
588602

589603
def _connect(

tests/mockliveserver/server.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from typing import Any
1414
from typing import Final
1515

16+
import zstandard
17+
from databento_dbn import Compression
1618
from databento_dbn import Schema
1719

1820
from databento.common import cram
@@ -79,6 +81,7 @@ def __init__(
7981
self._data = BytesIO()
8082
self._state = SessionState.NEW
8183
self._dataset: Dataset | None = None
84+
self._compression: Compression = Compression.NONE
8285
self._subscriptions: list[SubscriptionRequest] = []
8386
self._replay_tasks: set[asyncio.Task[None]] = set()
8487

@@ -194,6 +197,13 @@ def _(self, message: AuthenticationRequest) -> None:
194197

195198
self.state = SessionState.AUTHENTICATED
196199
self._dataset = Dataset(message.dataset)
200+
# Parse compression from the auth request; anything other than "zstd" falls back to no compression
201+
compression_str = getattr(message, "compression", "none")
202+
if compression_str == "zstd":
203+
self._compression = Compression.ZSTD
204+
else:
205+
self._compression = Compression.NONE
206+
logger.debug("client requested compression=%s", compression_str)
197207
self.send_gateway_message(
198208
self.get_authentication_response(
199209
success=True,
@@ -297,15 +307,37 @@ def _replay_done_callback(self, task: asyncio.Task[Any]) -> None:
297307
self.hangup(reason="all replay tasks completed")
298308

299309
async def _file_replay_task(self) -> None:
310+
compressor = None
311+
if self._compression == Compression.ZSTD:
312+
compressor = zstandard.ZstdCompressor()
313+
cctx = compressor.compressobj()
314+
300315
for subscription in self._subscriptions:
301316
schema = (
302317
Schema.from_str(subscription.schema)
303318
if isinstance(subscription.schema, str)
304319
else subscription.schema
305320
)
306321
replay = self._file_replay_table[(self.dataset, schema)]
307-
logger.info("starting replay %s for %s", replay.name, self.peer)
322+
logger.info(
323+
"starting replay %s for %s (compression=%s)",
324+
replay.name,
325+
self.peer,
326+
self._compression,
327+
)
308328
for chunk in replay:
309-
self.transport.write(chunk)
329+
if compressor is not None:
330+
compressed = cctx.compress(chunk)
331+
if compressed:
332+
self.transport.write(compressed)
333+
else:
334+
self.transport.write(chunk)
310335
await asyncio.sleep(0)
336+
337+
# Flush remaining compressed data
338+
if compressor is not None:
339+
remaining = cctx.flush()
340+
if remaining:
341+
self.transport.write(remaining)
342+
311343
logger.info("replay of %s completed for %s", replay.name, self.peer)

tests/test_live_compression.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from unittest.mock import MagicMock
5+
6+
import pytest
7+
from databento_dbn import Compression
8+
from databento_dbn import Schema
9+
10+
from databento.common.publishers import Dataset
11+
from databento.live.protocol import DatabentoLiveProtocol
12+
from tests.mockliveserver.fixture import MockLiveServerInterface
13+
14+
15+
@pytest.mark.parametrize("compression", [Compression.NONE, Compression.ZSTD])
16+
async def test_protocol_connection_with_compression(
17+
mock_live_server: MockLiveServerInterface,
18+
test_live_api_key: str,
19+
compression: Compression,
20+
) -> None:
21+
"""
22+
Test protocol connection with different compression settings.
23+
"""
24+
# Arrange
25+
transport, protocol = await asyncio.get_event_loop().create_connection(
26+
protocol_factory=lambda: DatabentoLiveProtocol(
27+
api_key=test_live_api_key,
28+
dataset=Dataset.GLBX_MDP3,
29+
compression=compression,
30+
),
31+
host=mock_live_server.host,
32+
port=mock_live_server.port,
33+
)
34+
35+
# Act, Assert
36+
await asyncio.wait_for(protocol.authenticated, timeout=1)
37+
transport.close()
38+
await asyncio.wait_for(protocol.disconnected, timeout=1)
39+
40+
41+
@pytest.mark.parametrize("compression", [Compression.NONE, Compression.ZSTD])
42+
async def test_protocol_streaming_with_compression(
43+
monkeypatch: pytest.MonkeyPatch,
44+
mock_live_server: MockLiveServerInterface,
45+
test_live_api_key: str,
46+
compression: Compression,
47+
) -> None:
48+
"""
49+
Test streaming records with different compression settings.
50+
"""
51+
# Arrange
52+
monkeypatch.setattr(
53+
DatabentoLiveProtocol,
54+
"received_metadata",
55+
metadata_mock := MagicMock(),
56+
)
57+
monkeypatch.setattr(
58+
DatabentoLiveProtocol,
59+
"received_record",
60+
record_mock := MagicMock(),
61+
)
62+
63+
_, protocol = await asyncio.get_event_loop().create_connection(
64+
protocol_factory=lambda: DatabentoLiveProtocol(
65+
api_key=test_live_api_key,
66+
dataset=Dataset.GLBX_MDP3,
67+
compression=compression,
68+
),
69+
host=mock_live_server.host,
70+
port=mock_live_server.port,
71+
)
72+
73+
await asyncio.wait_for(protocol.authenticated, timeout=1)
74+
75+
# Act
76+
protocol.subscribe(
77+
schema=Schema.MBO,
78+
symbols="ESM4",
79+
)
80+
protocol.start()
81+
82+
# Assert
83+
await asyncio.wait_for(protocol.disconnected, timeout=5)
84+
metadata_mock.assert_called()
85+
record_mock.assert_called()

tests/test_live_gateway_messages.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_parse_authentication_request(
9393
dataset=Dataset.GLBX_MDP3,
9494
client="unittest",
9595
),
96-
b"auth=abcd1234|dataset=GLBX.MDP3|encoding=dbn|ts_out=0|client=unittest\n",
96+
b"auth=abcd1234|dataset=GLBX.MDP3|encoding=dbn|ts_out=0|compression=none|client=unittest\n",
9797
),
9898
pytest.param(
9999
AuthenticationRequest(
@@ -103,7 +103,16 @@ def test_parse_authentication_request(
103103
client="unittest",
104104
heartbeat_interval_s=35,
105105
),
106-
b"auth=abcd1234|dataset=XNAS.ITCH|encoding=dbn|ts_out=1|heartbeat_interval_s=35|client=unittest\n",
106+
b"auth=abcd1234|dataset=XNAS.ITCH|encoding=dbn|ts_out=1|compression=none|heartbeat_interval_s=35|client=unittest\n",
107+
),
108+
pytest.param(
109+
AuthenticationRequest(
110+
auth="abc",
111+
dataset=Dataset.OPRA_PILLAR,
112+
compression="zstd",
113+
client="Databento",
114+
),
115+
b"auth=abc|dataset=OPRA.PILLAR|encoding=dbn|ts_out=0|compression=zstd|client=Databento\n",
107116
),
108117
],
109118
)

0 commit comments

Comments
 (0)