Skip to content

Commit e96dd96

Browse files
committed
Refactor zlib and add wbits to zlib.compress()
1 parent 8ff947e commit e96dd96

File tree

2 files changed

+115
-76
lines changed

2 files changed

+115
-76
lines changed

extra_tests/snippets/stdlib_zlib.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,9 @@
4848
b"x\xda\xf3\xc9/J\xcdU\xc8,(.\xcdUH\xc9\xcf\xc9/R(\xce,QH\xccM-\x01\x00\x83\xd5\t\xc5",
4949
]
5050

51-
for level, text in enumerate(compressed_lorem_list):
52-
assert zlib.compress(lorem, level) == text
51+
for level, expected in enumerate(compressed_lorem_list):
52+
actual = zlib.compress(lorem, level)
53+
assert actual == expected
5354

5455
# default level
5556
assert zlib.compress(lorem) == zlib.compress(lorem, -1) == zlib.compress(lorem, 6)

stdlib/src/zlib.rs

Lines changed: 112 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod zlib {
55
use crate::vm::{
66
builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyIntRef, PyTypeRef},
77
common::lock::PyMutex,
8-
function::{ArgBytesLike, ArgPrimitiveIndex, ArgSize, OptionalArg, OptionalOption},
8+
function::{ArgBytesLike, ArgPrimitiveIndex, ArgSize, OptionalArg},
99
PyPayload, PyResult, VirtualMachine,
1010
};
1111
use adler32::RollingAdler32 as Adler32;
@@ -47,7 +47,7 @@ mod zlib {
4747

4848
// copied from zlibmodule.c (commit 530f506ac91338)
4949
#[pyattr]
50-
const MAX_WBITS: u8 = 15;
50+
const MAX_WBITS: i8 = 15;
5151
#[pyattr]
5252
const DEF_BUF_SIZE: usize = 16 * 1024;
5353
#[pyattr]
@@ -78,8 +78,9 @@ mod zlib {
7878
crate::binascii::crc32(data, begin_state)
7979
}
8080

81-
fn compression_from_int(level: Option<i32>) -> Option<Compression> {
82-
match level.unwrap_or(Z_DEFAULT_COMPRESSION) {
81+
// TODO: rewrite with TryFromBorrowedObject
82+
fn compression_from_int(level: i32) -> Option<Compression> {
83+
match level {
8384
Z_DEFAULT_COMPRESSION => Some(Compression::default()),
8485
valid_level @ Z_NO_COMPRESSION..=Z_BEST_COMPRESSION => {
8586
Some(Compression::new(valid_level as u32))
@@ -92,23 +93,33 @@ mod zlib {
9293
struct PyFuncCompressArgs {
9394
#[pyarg(positional)]
9495
data: ArgBytesLike,
95-
#[pyarg(any, optional)]
96-
level: OptionalOption<i32>,
96+
#[pyarg(any, default = "Z_DEFAULT_COMPRESSION")]
97+
level: i32,
98+
#[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")]
99+
wbits: ArgPrimitiveIndex<i8>,
97100
}
98101

99102
/// Returns a bytes object containing compressed data.
100103
#[pyfunction]
101104
fn compress(args: PyFuncCompressArgs, vm: &VirtualMachine) -> PyResult<PyBytesRef> {
102-
let data = args.data;
103-
let level = args.level;
105+
let PyFuncCompressArgs {
106+
data,
107+
level,
108+
ref wbits,
109+
} = args;
104110

105-
let compression = compression_from_int(level.flatten())
111+
let level = compression_from_int(level)
106112
.ok_or_else(|| new_zlib_error("Bad compression level", vm))?;
107113

108-
let mut encoder = ZlibEncoder::new(Vec::new(), compression);
109-
data.with_ref(|input_bytes| encoder.write_all(input_bytes).unwrap());
110-
let encoded_bytes = encoder.finish().unwrap();
111-
114+
let encoded_bytes = if args.wbits.value == MAX_WBITS {
115+
let mut encoder = ZlibEncoder::new(Vec::new(), level);
116+
data.with_ref(|input_bytes| encoder.write_all(input_bytes).unwrap());
117+
encoder.finish().unwrap()
118+
} else {
119+
let mut inner = CompressInner::new(InitOptions::new(wbits.value, vm)?.compress(level));
120+
data.with_ref(|input_bytes| inner.compress(input_bytes, vm))?;
121+
inner.flush(vm)?
122+
};
112123
Ok(vm.ctx.new_bytes(encoded_bytes))
113124
}
114125

@@ -125,6 +136,21 @@ mod zlib {
125136
}
126137

127138
impl InitOptions {
139+
fn new(wbits: i8, vm: &VirtualMachine) -> PyResult<InitOptions> {
140+
let header = wbits > 0;
141+
let wbits = wbits.unsigned_abs();
142+
match wbits {
143+
9..=15 => Ok(InitOptions::Standard {
144+
header,
145+
#[cfg(feature = "zlib")]
146+
wbits,
147+
}),
148+
#[cfg(feature = "zlib")]
149+
25..=31 => Ok(InitOptions::Gzip { wbits: wbits - 16 }),
150+
_ => Err(vm.new_value_error("Invalid initialization option".to_owned())),
151+
}
152+
}
153+
128154
fn decompress(self) -> Decompress {
129155
match self {
130156
#[cfg(not(feature = "zlib"))]
@@ -149,22 +175,6 @@ mod zlib {
149175
}
150176
}
151177

152-
fn header_from_wbits(wbits: OptionalArg<i8>, vm: &VirtualMachine) -> PyResult<InitOptions> {
153-
let wbits = wbits.unwrap_or(MAX_WBITS as i8);
154-
let header = wbits > 0;
155-
let wbits = wbits.unsigned_abs();
156-
match wbits {
157-
9..=15 => Ok(InitOptions::Standard {
158-
header,
159-
#[cfg(feature = "zlib")]
160-
wbits,
161-
}),
162-
#[cfg(feature = "zlib")]
163-
25..=31 => Ok(InitOptions::Gzip { wbits: wbits - 16 }),
164-
_ => Err(vm.new_value_error("Invalid initialization option".to_owned())),
165-
}
166-
}
167-
168178
fn _decompress(
169179
mut data: &[u8],
170180
d: &mut Decompress,
@@ -232,43 +242,55 @@ mod zlib {
232242
struct PyFuncDecompressArgs {
233243
#[pyarg(positional)]
234244
data: ArgBytesLike,
235-
#[pyarg(any, optional)]
236-
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
237-
#[pyarg(any, optional)]
238-
bufsize: OptionalArg<ArgPrimitiveIndex<usize>>,
245+
#[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")]
246+
wbits: ArgPrimitiveIndex<i8>,
247+
#[pyarg(any, default = "ArgPrimitiveIndex { value: DEF_BUF_SIZE }")]
248+
bufsize: ArgPrimitiveIndex<usize>,
239249
}
240250

241251
/// Returns a bytes object containing the uncompressed data.
242252
#[pyfunction]
243-
fn decompress(arg: PyFuncDecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
244-
let data = arg.data;
245-
let wbits = arg.wbits;
246-
let bufsize = arg.bufsize;
253+
fn decompress(args: PyFuncDecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
254+
let PyFuncDecompressArgs {
255+
data,
256+
wbits,
257+
bufsize,
258+
} = args;
247259
data.with_ref(|data| {
248-
let bufsize = bufsize.into_primitive().unwrap_or(DEF_BUF_SIZE);
249-
250-
let mut d = header_from_wbits(wbits.into_primitive(), vm)?.decompress();
260+
let mut d = InitOptions::new(wbits.value, vm)?.decompress();
251261

252-
_decompress(data, &mut d, bufsize, None, false, vm).and_then(|(buf, stream_end)| {
253-
if stream_end {
254-
Ok(buf)
255-
} else {
256-
Err(new_zlib_error(
257-
"Error -5 while decompressing data: incomplete or truncated stream",
258-
vm,
259-
))
260-
}
261-
})
262+
_decompress(data, &mut d, bufsize.value, None, false, vm).and_then(
263+
|(buf, stream_end)| {
264+
if stream_end {
265+
Ok(buf)
266+
} else {
267+
Err(new_zlib_error(
268+
"Error -5 while decompressing data: incomplete or truncated stream",
269+
vm,
270+
))
271+
}
272+
},
273+
)
262274
})
263275
}
264276

277+
#[derive(FromArgs)]
278+
struct DecompressobjArgs {
279+
#[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")]
280+
wbits: ArgPrimitiveIndex<i8>,
281+
#[cfg(feature = "zlib")]
282+
#[pyarg(any, optional)]
283+
_zdict: OptionalArg<ArgBytesLike>,
284+
}
285+
265286
#[pyfunction]
266287
fn decompressobj(args: DecompressobjArgs, vm: &VirtualMachine) -> PyResult<PyDecompress> {
267288
#[allow(unused_mut)]
268-
let mut decompress = header_from_wbits(args.wbits.into_primitive(), vm)?.decompress();
289+
let mut decompress = InitOptions::new(args.wbits.value, vm)?.decompress();
269290
#[cfg(feature = "zlib")]
270-
if let OptionalArg::Present(dict) = args.zdict {
271-
dict.with_ref(|d| decompress.set_dictionary(d).unwrap());
291+
if let OptionalArg::Present(_dict) = args._zdict {
292+
// FIXME: always fails
293+
// dict.with_ref(|d| decompress.set_dictionary(d));
272294
}
273295
Ok(PyDecompress {
274296
decompress: PyMutex::new(decompress),
@@ -407,34 +429,44 @@ mod zlib {
407429
}
408430

409431
#[derive(FromArgs)]
410-
struct DecompressobjArgs {
411-
#[pyarg(any, optional)]
412-
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
432+
#[allow(dead_code)] // FIXME: use args
433+
struct CompressobjArgs {
434+
#[pyarg(any, default = "Z_DEFAULT_COMPRESSION")]
435+
level: i32,
436+
// only DEFLATED is valid right now, it's w/e
437+
#[pyarg(any, default = "DEFLATED")]
438+
_method: i32,
439+
#[pyarg(any, default = "ArgPrimitiveIndex { value: MAX_WBITS }")]
440+
wbits: ArgPrimitiveIndex<i8>,
441+
#[pyarg(any, name = "_memLevel", default = "DEF_MEM_LEVEL")]
442+
_mem_level: u8,
443+
#[cfg(feature = "zlib")]
444+
#[pyarg(any, default = "Z_DEFAULT_STRATEGY")]
445+
_strategy: i32,
413446
#[cfg(feature = "zlib")]
414447
#[pyarg(any, optional)]
415-
zdict: OptionalArg<ArgBytesLike>,
448+
zdict: Option<ArgBytesLike>,
416449
}
417450

418451
#[pyfunction]
419-
fn compressobj(
420-
level: OptionalArg<i32>,
421-
// only DEFLATED is valid right now, it's w/e
422-
_method: OptionalArg<i32>,
423-
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
424-
// these aren't used.
425-
_mem_level: OptionalArg<i32>, // this is memLevel in CPython
426-
_strategy: OptionalArg<i32>,
427-
_zdict: OptionalArg<ArgBytesLike>,
428-
vm: &VirtualMachine,
429-
) -> PyResult<PyCompress> {
430-
let level = compression_from_int(level.into_option())
452+
fn compressobj(args: CompressobjArgs, vm: &VirtualMachine) -> PyResult<PyCompress> {
453+
let CompressobjArgs {
454+
level,
455+
wbits,
456+
#[cfg(feature = "zlib")]
457+
zdict,
458+
..
459+
} = args;
460+
let level = compression_from_int(level)
431461
.ok_or_else(|| vm.new_value_error("invalid initialization option".to_owned()))?;
432-
let compress = header_from_wbits(wbits.into_primitive(), vm)?.compress(level);
462+
#[allow(unused_mut)]
463+
let mut compress = InitOptions::new(wbits.value, vm)?.compress(level);
464+
#[cfg(feature = "zlib")]
465+
if let Some(zdict) = zdict {
466+
zdict.with_ref(|zdict| compress.set_dictionary(zdict).unwrap());
467+
}
433468
Ok(PyCompress {
434-
inner: PyMutex::new(CompressInner {
435-
compress,
436-
unconsumed: Vec::new(),
437-
}),
469+
inner: PyMutex::new(CompressInner::new(compress)),
438470
})
439471
}
440472

@@ -477,6 +509,12 @@ mod zlib {
477509
const CHUNKSIZE: usize = u32::MAX as usize;
478510

479511
impl CompressInner {
512+
fn new(compress: Compress) -> Self {
513+
Self {
514+
compress,
515+
unconsumed: Vec::new(),
516+
}
517+
}
480518
fn compress(&mut self, data: &[u8], vm: &VirtualMachine) -> PyResult<Vec<u8>> {
481519
let orig_in = self.compress.total_in() as usize;
482520
let mut cur_in = 0;

0 commit comments

Comments
 (0)