Skip to content

Commit fba12be

Browse files
committed
bugfix for zstd decompressobj reuse
Fixes #3538
1 parent b5addb6 commit fba12be

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

httpx/_decoders.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def __init__(self) -> None:
176176

177177
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
178178
self.seen_data = False
179+
self.seen_eof = False
179180

180181
def decode(self, data: bytes) -> bytes:
181182
assert zstandard is not None
@@ -187,6 +188,11 @@ def decode(self, data: bytes) -> bytes:
187188
unused_data = self.decompressor.unused_data
188189
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
189190
output.write(self.decompressor.decompress(unused_data))
191+
# If the decompressor reached EOF, create a new one for the next call
192+
# since zstd decompressors cannot be reused after EOF
193+
if self.decompressor.eof:
194+
self.seen_eof = True
195+
self.decompressor = zstandard.ZstdDecompressor().decompressobj()
190196
except zstandard.ZstdError as exc:
191197
raise DecodingError(str(exc)) from exc
192198
return output.getvalue()
@@ -195,7 +201,7 @@ def flush(self) -> bytes:
195201
if not self.seen_data:
196202
return b""
197203
ret = self.decompressor.flush() # note: this is a no-op
198-
if not self.decompressor.eof:
204+
if not self.decompressor.eof and not self.seen_eof:
199205
raise DecodingError("Zstandard data is incomplete") # pragma: no cover
200206
return bytes(ret)
201207

tests/test_decoders.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,26 @@ def test_zstd_multiframe():
141141
assert response.content == b"foobar"
142142

143143

144+
def test_zstd_streaming_multiple_frames():
145+
body1 = b"test 123 "
146+
body2 = b"another frame"
147+
148+
# Create two separate complete frames
149+
frame1 = zstd.compress(body1)
150+
frame2 = zstd.compress(body2)
151+
152+
# Create an iterator that yields frames separately
153+
def content_iterator() -> typing.Iterator[bytes]:
154+
yield frame1
155+
yield frame2
156+
157+
headers = [(b"Content-Encoding", b"zstd")]
158+
response = httpx.Response(200, headers=headers, content=content_iterator())
159+
response.read()
160+
161+
assert response.content == body1 + body2
162+
163+
144164
def test_multi():
145165
body = b"test 123"
146166

0 commit comments

Comments
 (0)