Skip to content

Commit 0e2e7e5

Browse files
committed
ArgPrimitiveIndex for zlib
1 parent d71910c commit 0e2e7e5

File tree

4 files changed

+51
-31
lines changed

4 files changed

+51
-31
lines changed

stdlib/src/zlib.rs

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ pub(crate) use zlib::make_module;
22

33
#[pymodule]
44
mod zlib {
5-
use crate::common::lock::PyMutex;
65
use crate::vm::{
76
builtins::{PyBaseExceptionRef, PyBytes, PyBytesRef, PyIntRef, PyTypeRef},
8-
function::{ArgBytesLike, OptionalArg, OptionalOption},
7+
common::lock::PyMutex,
8+
function::{ArgBytesLike, ArgPrimitiveIndex, ArgSize, OptionalArg, OptionalOption},
99
PyPayload, PyResult, VirtualMachine,
1010
};
1111
use adler32::RollingAdler32 as Adler32;
@@ -233,9 +233,9 @@ mod zlib {
233233
#[pyarg(positional)]
234234
data: ArgBytesLike,
235235
#[pyarg(any, optional)]
236-
wbits: OptionalArg<i8>,
236+
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
237237
#[pyarg(any, optional)]
238-
bufsize: OptionalArg<usize>,
238+
bufsize: OptionalArg<ArgPrimitiveIndex<usize>>,
239239
}
240240

241241
/// Returns a bytes object containing the uncompressed data.
@@ -245,9 +245,9 @@ mod zlib {
245245
let wbits = arg.wbits;
246246
let bufsize = arg.bufsize;
247247
data.with_ref(|data| {
248-
let bufsize = bufsize.unwrap_or(DEF_BUF_SIZE);
248+
let bufsize = bufsize.into_primitive().unwrap_or(DEF_BUF_SIZE);
249249

250-
let mut d = header_from_wbits(wbits, vm)?.decompress();
250+
let mut d = header_from_wbits(wbits.into_primitive(), vm)?.decompress();
251251

252252
_decompress(data, &mut d, bufsize, None, false, vm).and_then(|(buf, stream_end)| {
253253
if stream_end {
@@ -265,7 +265,7 @@ mod zlib {
265265
#[pyfunction]
266266
fn decompressobj(args: DecompressobjArgs, vm: &VirtualMachine) -> PyResult<PyDecompress> {
267267
#[allow(unused_mut)]
268-
let mut decompress = header_from_wbits(args.wbits, vm)?.decompress();
268+
let mut decompress = header_from_wbits(args.wbits.into_primitive(), vm)?.decompress();
269269
#[cfg(feature = "zlib")]
270270
if let OptionalArg::Present(dict) = args.zdict {
271271
dict.with_ref(|d| decompress.set_dictionary(d).unwrap());
@@ -325,11 +325,8 @@ mod zlib {
325325

326326
#[pymethod]
327327
fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
328-
let max_length = if args.max_length == 0 {
329-
None
330-
} else {
331-
Some(args.max_length)
332-
};
328+
let max_length = args.max_length.value;
329+
let max_length = (max_length != 0).then_some(max_length);
333330
let data = args.data.borrow_buf();
334331
let data = &*data;
335332

@@ -362,12 +359,18 @@ mod zlib {
362359
}
363360

364361
#[pymethod]
365-
fn flush(&self, length: OptionalArg<isize>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
362+
fn flush(&self, length: OptionalArg<ArgSize>, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
366363
let length = match length {
367-
OptionalArg::Present(l) if l <= 0 => {
368-
return Err(vm.new_value_error("length must be greater than zero".to_owned()));
364+
OptionalArg::Present(l) => {
365+
let l: isize = l.into();
366+
if l <= 0 {
367+
return Err(
368+
vm.new_value_error("length must be greater than zero".to_owned())
369+
);
370+
} else {
371+
l as usize
372+
}
369373
}
370-
OptionalArg::Present(l) => l as usize,
371374
OptionalArg::Missing => DEF_BUF_SIZE,
372375
};
373376

@@ -396,14 +399,17 @@ mod zlib {
396399
struct DecompressArgs {
397400
#[pyarg(positional)]
398401
data: ArgBytesLike,
399-
#[pyarg(any, default = "0")]
400-
max_length: usize,
402+
#[pyarg(
403+
any,
404+
default = "rustpython_vm::function::ArgPrimitiveIndex { value: 0 }"
405+
)]
406+
max_length: ArgPrimitiveIndex<usize>,
401407
}
402408

403409
#[derive(FromArgs)]
404410
struct DecompressobjArgs {
405411
#[pyarg(any, optional)]
406-
wbits: OptionalArg<i8>,
412+
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
407413
#[cfg(feature = "zlib")]
408414
#[pyarg(any, optional)]
409415
zdict: OptionalArg<ArgBytesLike>,
@@ -414,7 +420,7 @@ mod zlib {
414420
level: OptionalArg<i32>,
415421
// only DEFLATED is valid right now, it's w/e
416422
_method: OptionalArg<i32>,
417-
wbits: OptionalArg<i8>,
423+
wbits: OptionalArg<ArgPrimitiveIndex<i8>>,
418424
// these aren't used.
419425
_mem_level: OptionalArg<i32>, // this is memLevel in CPython
420426
_strategy: OptionalArg<i32>,
@@ -423,7 +429,7 @@ mod zlib {
423429
) -> PyResult<PyCompress> {
424430
let level = compression_from_int(level.into_option())
425431
.ok_or_else(|| vm.new_value_error("invalid initialization option".to_owned()))?;
426-
let compress = header_from_wbits(wbits, vm)?.compress(level);
432+
let compress = header_from_wbits(wbits.into_primitive(), vm)?.compress(level);
427433
Ok(PyCompress {
428434
inner: PyMutex::new(CompressInner {
429435
compress,

vm/src/builtins/str.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
},
1414
convert::{IntoPyException, ToPyException, ToPyObject, ToPyResult},
1515
format::{format, format_map},
16-
function::{ArgSize, ArgIterable, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue},
16+
function::{ArgIterable, ArgSize, FuncArgs, OptionalArg, OptionalOption, PyComparisonValue},
1717
intern::PyInterned,
1818
protocol::{PyIterReturn, PyMappingMethods, PyNumberMethods, PySequenceMethods},
1919
sequence::SequenceExt,

vm/src/function/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub use builtin::{IntoPyNativeFunc, OwnedParam, PyNativeFunc, RefParam};
1717
pub use either::Either;
1818
pub use getset::PySetterValue;
1919
pub(super) use getset::{IntoPyGetterFunc, IntoPySetterFunc, PyGetterFunc, PySetterFunc};
20-
pub use number::{ArgIndex, ArgIntoBool, ArgIntoComplex, ArgIntoFloat, ArgSize};
20+
pub use number::{ArgIndex, ArgIntoBool, ArgIntoComplex, ArgIntoFloat, ArgPrimitiveIndex, ArgSize};
2121
pub use protocol::{ArgCallable, ArgIterable, ArgMapping, ArgSequence};
2222

2323
use crate::{builtins::PyStr, convert::TryFromBorrowedObject, PyObject, PyResult, VirtualMachine};

vm/src/function/number.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
use super::argument::OptionalArg;
12
use crate::{builtins::PyIntRef, AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine};
3+
use num_bigint::BigInt;
24
use num_complex::Complex64;
5+
use num_traits::PrimInt;
36
use std::ops::Deref;
47

58
/// A Python complex-like object.
@@ -157,28 +160,39 @@ impl TryFromObject for ArgIndex {
157160

158161
#[derive(Debug)]
159162
#[repr(transparent)]
160-
pub struct ArgSize {
161-
value: isize,
163+
pub struct ArgPrimitiveIndex<T> {
164+
pub value: T,
162165
}
163166

164-
impl From<ArgSize> for isize {
165-
fn from(arg: ArgSize) -> Self {
166-
arg.value
167+
impl<T> OptionalArg<ArgPrimitiveIndex<T>> {
168+
pub fn into_primitive(self) -> OptionalArg<T> {
169+
self.map(|x| x.value)
167170
}
168171
}
169172

170-
impl Deref for ArgSize {
171-
type Target = isize;
173+
impl<T> Deref for ArgPrimitiveIndex<T> {
174+
type Target = T;
172175

173176
fn deref(&self) -> &Self::Target {
174177
&self.value
175178
}
176179
}
177180

178-
impl TryFromObject for ArgSize {
181+
impl<T> TryFromObject for ArgPrimitiveIndex<T>
182+
where
183+
T: PrimInt + for<'a> TryFrom<&'a BigInt>,
184+
{
179185
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
180186
Ok(Self {
181187
value: obj.try_index(vm)?.try_to_primitive(vm)?,
182188
})
183189
}
184190
}
191+
192+
pub type ArgSize = ArgPrimitiveIndex<isize>;
193+
194+
impl From<ArgSize> for isize {
195+
fn from(arg: ArgSize) -> Self {
196+
arg.value
197+
}
198+
}

0 commit comments

Comments
 (0)